from functools import partial
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from models.sandbox.fno_diag import SpectralKernel1dDiag
from models.sandbox.fno_freq_only import SpectralKernel1dFreqOnly
from models.sandbox.fno_hidden_only import SpectralKernel1dHiddenOnly
from models.sandbox.fno_time_param import SpectralKernel1dTimeParam
from models.custom_layers import get_residual_layer
from models.fast_model import fast_input_layer, fast_output_layer

def get_spectral_kernel_1d(spectral_type, in_channels, out_channels, modes1):
    """Build the 1D spectral kernel registered under ``spectral_type``.

    The selected factory is called as
    ``factory(in_channels, out_channels, modes1)``. The ``"identity"`` entry
    ignores all arguments and returns an ``nn.Identity`` module.

    Raises:
        KeyError: if ``spectral_type`` is not one of the registered names.
    """
    def _identity(*args, **kwargs):
        # Accept (and discard) the common constructor arguments.
        return nn.Identity()

    factories = dict(
        full=SpectralKernel1d,
        diag=SpectralKernel1dDiag,
        freq_only=SpectralKernel1dFreqOnly,
        hidden_only=SpectralKernel1dHiddenOnly,
        identity=_identity,
        time_param=SpectralKernel1dTimeParam,
    )
    build = factories[spectral_type]
    return build(in_channels, out_channels, modes1)


class SpectralKernel1d(nn.Module):
    """Spectral FNO kernel with one complex weight per (in, out, mode) triple."""

    def __init__(self, in_channels, out_channels, modes1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1.
        self.modes1 = modes1

        # Uniform [0, 1) complex samples shrunk by 1 / (in_channels * out_channels).
        self.scale = 1 / (in_channels * out_channels)
        initial = torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)
        self.weights1 = nn.Parameter(self.scale * initial)

    def forward(self, x):
        """Contract the input-channel axis of ``x`` against the weights.

        (batch, in_channel, mode), (in_channel, out_channel, mode)
            -> (batch, out_channel, mode)
        """
        return torch.einsum("bim,iom->bom", x, self.weights1)
    

class SpectralConv1d(nn.Module):
    """1D Fourier layer. It does FFT, linear transform, and inverse FFT.

    Only the lowest ``modes1`` Fourier coefficients along the last axis are
    passed through the spectral kernel; all remaining coefficients are set to
    zero before transforming back to physical space.

    Args:
        in_channels: number of input channels handed to the kernel factory.
        out_channels: number of channels of the returned tensor.
        modes1: number of low-frequency modes that are kept.
        spectral_type: registry name understood by ``get_spectral_kernel_1d``.
    """

    def __init__(self, in_channels, out_channels, modes1, spectral_type):
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.kernel = get_spectral_kernel_1d(spectral_type, in_channels, out_channels, modes1)

    def forward(self, x):
        """Apply the spectral convolution along the last axis of ``x``.

        The output has ``out_channels`` channels and the same last-axis length
        and floating-point precision as the input.
        """
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to a constant factor.
        x_ft = torch.fft.rfft(x)

        # Allocate the spectrum in the dtype produced by the FFT so that
        # double-precision inputs are not silently truncated to complex64.
        out_ft = torch.zeros(
            batchsize, self.out_channels,
            x.size(-1) // 2 + 1,
            device=x.device, dtype=x_ft.dtype)
        # Multiply the retained Fourier modes; the rest stay zero.
        out_ft[:, :, :self.modes1] = self.kernel(
            x_ft[:, :, :self.modes1])

        # Return to physical space.
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x
    

class MLP1d(nn.Module):
    """Two kernel-size-1 ``Conv1d`` layers with a GELU in between."""

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        self.mlp1 = nn.Conv1d(in_channels, mid_channels, kernel_size=1)
        self.mlp2 = nn.Conv1d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``in_channels`` -> ``mid_channels`` -> ``out_channels``."""
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class FNO1d(nn.Module):
    """1D Fourier Neural Operator.

    Pipeline implemented by ``forward``:

    1. Concatenate the input with ``grid`` along the last axis and lift it to
       ``d_model`` channels with ``self.p``.
    2. Apply ``n_layers - 1`` blocks. Each block computes
       ``gelu(mlp(K(u)) + W(u))`` and adds ``res(u)``, where ``K`` is a
       ``SpectralConv1d``, ``W`` a kernel-size-1 ``Conv1d``, and ``res`` the
       layer returned by ``get_residual_layer``.
    3. Apply one final block ``mlp(K(u)) + W(u)`` without activation or
       residual term (``self.conv``, ``self.mlp``, ``self.w``).
    4. Project to ``d_output`` channels with ``self.q``.

    Args:
        d_input: number of input features; the lifting layer expects
            ``d_input + 1`` features after ``grid`` is concatenated
            (presumably a single coordinate channel -- confirm with callers).
        d_output: number of output channels produced by ``self.q``.
        n_layers: total number of Fourier blocks, including the final one.
        modes: number of Fourier modes kept by each spectral convolution.
        d_model: hidden channel width.
        initial_step: accepted for interface compatibility; not used here.
        spectral_type: kernel name forwarded to ``SpectralConv1d``.
        residual_type: name forwarded to ``get_residual_layer``.
        fast: optional mapping. When ``fast["use_fast"]`` is truthy, the
            lifting and projection layers are built with ``fast_input_layer``
            and ``fast_output_layer`` using ``fast["kernel_size"]`` and
            ``fast["stride"]``. ``None`` (the default) behaves like ``{}``.
        **kwargs: ignored.
    """

    def __init__(
            self,
            d_input,
            d_output,
            n_layers=4,
            modes=16,
            d_model=64,
            initial_step=1,
            spectral_type="full",
            residual_type="zero",
            fast=None,
            **kwargs):
        super(FNO1d, self).__init__()

        # Use a fresh dict per call instead of a shared mutable default.
        fast = {} if fast is None else fast

        self.modes1 = modes
        self.d_model = d_model
        # Padding width for non-periodic domains; the pad/unpad steps in
        # ``forward`` are currently disabled, so this value is not used.
        self.padding = 2
        # One extra input feature for the concatenated grid.
        self._d_input = d_input + 1
        if fast.get("use_fast", False):
            self.p = fast_input_layer(fast["kernel_size"], fast["stride"], self._d_input, self.d_model, n_dim=1)
            self.q = fast_output_layer(fast["kernel_size"], fast["stride"], self.d_model, d_output, n_dim=1, transposed=False)
        else:
            self.p = nn.Linear(self._d_input, self.d_model)
            self.q = MLP1d(self.d_model, d_output, self.d_model * 2)

        self.fno_layers = nn.ModuleList()
        self.ws = nn.ModuleList()
        self.mlps = nn.ModuleList()
        self.residual_layers = nn.ModuleList()

        spectral_layer = partial(SpectralConv1d, self.d_model, self.d_model, self.modes1, spectral_type)

        # All but the last Fourier block live in the module lists.
        for _ in range(n_layers - 1):
            self.fno_layers.append(
                spectral_layer())
            self.ws.append(nn.Conv1d(d_model, d_model, kernel_size=1))
            self.mlps.append(MLP1d(self.d_model, self.d_model, self.d_model))
            self.residual_layers.append(get_residual_layer(residual_type, d_model))

        # Final Fourier block (applied without activation or residual).
        self.conv = spectral_layer()
        self.w = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.mlp = MLP1d(self.d_model, self.d_model, self.d_model)

    def forward(self, x, grid):
        """Run the operator on ``x`` with coordinates ``grid``.

        A 4-D ``x`` whose second-to-last axis has size 1 is squeezed before
        processing and the axis is restored on the output. The result is
        returned channels-last.
        """
        # x dim = [b, x1, t*v]
        if len(x.shape) == 4 and x.shape[-2] == 1:
            squeezed = True
            x = x.squeeze(-2)
        else:
            squeezed = False
        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        # Move channels in front of the spatial axis for the conv layers.
        x = x.permute(0, 2, 1)

        # x = F.pad(x, [0, self.padding]) # pad the domain if input is non-periodic
        for layer, w, mlp, res in zip(self.fno_layers, self.ws, self.mlps, self.residual_layers):
            x_ = x
            x1 = layer(x)
            x1 = mlp(x1)
            x2 = w(x)
            x = x1 + x2
            x = F.gelu(x)
            x = x + res(x_)

        x1 = self.conv(x)
        x1 = self.mlp(x1)
        x2 = self.w(x)
        x = x1 + x2

        # x = x[..., :-self.padding]
        x = self.q(x)
        # Back to channels-last.
        x = x.permute(0, 2, 1)
        if squeezed:
            x = x.unsqueeze(-2)
        return x


class SpectralConv2d(nn.Module):
    """2D Fourier layer. It does FFT, linear transform, and inverse FFT.

    Two blocks of Fourier coefficients are transformed, each with its own
    complex weight tensor: the first ``modes1`` and the last ``modes1`` rows
    along the second-to-last axis, both restricted to the first ``modes2``
    columns of the real-FFT axis. All other coefficients are set to zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of modes kept at each end of the second-to-last axis
            (at most floor(N/2) + 1).
        modes2: number of modes kept along the last (real-FFT) axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """Apply the spectral convolution over the last two axes of ``x``.

        The output has ``out_channels`` channels and the same spatial size and
        floating-point precision as the input.
        """
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to a constant factor.
        x_ft = torch.fft.rfft2(x)

        # Allocate the spectrum in the dtype produced by the FFT so that
        # double-precision data is not silently truncated to complex64.
        out_ft = torch.zeros(
            batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1,
            dtype=x_ft.dtype, device=x.device)
        # Multiply the retained Fourier modes; the rest stay zero.
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space.
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x

class MLP2d(nn.Module):
    """Two kernel-size-1 ``Conv2d`` layers with a GELU in between."""

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``in_channels`` -> ``mid_channels`` -> ``out_channels``."""
        hidden = F.gelu(self.mlp1(x))
        return self.mlp2(hidden)

class FNO2d(nn.Module):
    """2D Fourier Neural Operator.

    Pipeline implemented by ``forward``:

    1. Concatenate the input with ``grid`` along the last axis and lift it to
       ``d_model`` channels with ``self.p``.
    2. Apply ``n_layers - 1`` blocks computing ``gelu(mlp(K(u)) + W(u))``,
       where ``K`` is a ``SpectralConv2d`` and ``W`` a kernel-size-1
       ``Conv2d``.
    3. Apply one final block ``mlp(norm(K(norm(u)))) + W(u)`` without
       activation, where ``norm`` is an ``InstanceNorm2d``.
    4. Project to ``n_states`` channels with ``self.q``.

    input shape: (batchsize, x, y, c)
    output shape: (batchsize, x, y, c)

    Args:
        modes: number of Fourier modes kept along the first spatial axis.
        modes2: number of Fourier modes kept along the second spatial axis.
        d_model: hidden channel width.
        n_layers: total number of Fourier blocks, including the final one.
        initial_step: accepted for interface compatibility; not used here.
        n_states: number of output channels; the lifting layer expects
            ``n_states + 2`` features after ``grid`` is concatenated
            (presumably two coordinate channels -- confirm with callers).
        fast: optional mapping. When ``fast["use_fast"]`` is truthy, the
            lifting and projection layers are built with ``fast_input_layer``
            and ``fast_output_layer`` using ``fast["kernel_size"]`` and
            ``fast["stride"]``. ``None`` (the default) behaves like ``{}``.
        **kwargs: ignored.
    """

    def __init__(
            self,
            modes=12,
            modes2=12,
            d_model=20,
            n_layers=4,
            initial_step=1,
            n_states=1,
            fast=None,
            **kwargs):
        super(FNO2d, self).__init__()

        # Use a fresh dict per call instead of a shared mutable default.
        fast = {} if fast is None else fast

        self.modes1 = modes
        self.modes2 = modes2
        self.d_model = d_model
        # Padding width for non-periodic domains; the pad/unpad steps in
        # ``forward`` are currently disabled, so this value is not used.
        self.padding = 2
        # Two extra input features for the concatenated grid.
        self._d_input = n_states + 2
        d_output = n_states
        if fast.get("use_fast", False):
            self.p = fast_input_layer(fast["kernel_size"], fast["stride"], self._d_input, self.d_model, n_dim=2)
            self.q = fast_output_layer(fast["kernel_size"], fast["stride"], self.d_model, d_output, n_dim=2, transposed=False)
        else:
            self.p = nn.Linear(self._d_input, self.d_model)
            self.q = MLP2d(self.d_model, d_output, self.d_model * 4)

        self.fno_layers = nn.ModuleList()
        self.ws = nn.ModuleList()
        self.mlps = nn.ModuleList()
        # All but the last Fourier block live in the module lists.
        for _ in range(n_layers - 1):
            self.fno_layers.append(
                SpectralConv2d(self.d_model, self.d_model, self.modes1, self.modes2))
            self.ws.append(nn.Conv2d(self.d_model, self.d_model, 1))
            self.mlps.append(MLP2d(self.d_model, self.d_model, self.d_model))
        # Final Fourier block (applied without activation).
        self.conv = SpectralConv2d(
            self.d_model, self.d_model, self.modes1, self.modes2)
        self.w = nn.Conv2d(self.d_model, self.d_model, 1)
        self.mlp = MLP2d(self.d_model, self.d_model, self.d_model)
        self.norm = nn.InstanceNorm2d(self.d_model)

    def forward(self, x, grid):
        """Run the operator on ``x`` with coordinates ``grid``.

        Both tensors are channels-last; the result is returned channels-last.
        """
        # x dim = [b, x1, x2, t*v]
        x = torch.cat((x, grid), dim=-1)
        x = self.p(x)
        # Move channels in front of the spatial axes for the conv layers.
        x = x.permute(0, 3, 1, 2)

        # # Pad tensor with boundary condition
        # x = F.pad(x, [0, self.padding, 0, self.padding])

        for layer, w, mlp in zip(self.fno_layers, self.ws, self.mlps):
            x1 = layer(x)
            x1 = mlp(x1)
            x2 = w(x)
            x = x1 + x2
            x = F.gelu(x)

        # Only the final spectral convolution is wrapped in instance norm.
        x1 = self.norm(self.conv(self.norm(x)))
        x1 = self.mlp(x1)
        x2 = self.w(x)
        x = x1 + x2

        # x = x[..., :-self.padding, :-self.padding] # Unpad the tensor
        x = self.q(x)
        # Back to channels-last.
        x = x.permute(0, 2, 3, 1)

        return x
    

class SpectralConv3d(nn.Module):
    """3D Fourier layer. It does FFT, linear transform, and inverse FFT.

    Four blocks of Fourier coefficients are transformed, each with its own
    complex weight tensor: every combination of the first/last ``modes1``
    entries along axis -3 and the first/last ``modes2`` entries along axis -2,
    restricted to the first ``modes3`` entries of the real-FFT axis. All other
    coefficients are set to zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: modes kept at each end of axis -3 (at most floor(N/2) + 1).
        modes2: modes kept at each end of axis -2.
        modes3: modes kept along the last (real-FFT) axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        """Apply the spectral convolution over the last three axes of ``x``.

        The output has ``out_channels`` channels and the same spatial size and
        floating-point precision as the input.
        """
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to a constant factor.
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        # Allocate the spectrum in the dtype produced by the FFT so that
        # double-precision data is not silently truncated to complex64.
        out_ft = torch.zeros(
            batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1) // 2 + 1,
            dtype=x_ft.dtype, device=x.device)
        # Multiply the retained Fourier modes; the rest stay zero.
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        # Return to physical space.
        x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))
        return x

class FNO3d(nn.Module):
    """3D Fourier Neural Operator with four Fourier stages.

    Pipeline implemented by ``forward``:

    1. Concatenate the input with ``grid`` along the last axis and lift it to
       ``width`` channels with ``self.fc0``, which expects
       ``initial_step * num_channels + 3`` features.
    2. Right-pad the last axis by ``self.padding`` entries.
    3. Apply four stages ``u' = K(u) + W(u)`` (``K``: ``convN``, a
       ``SpectralConv3d``; ``W``: ``wN``, a kernel-size-1 ``Conv3d``), with a
       GELU after every stage except the last.
    4. Drop the padding and project with ``self.fc1`` -> GELU -> ``self.fc2``
       to ``num_channels`` features.

    The ``bn0`` .. ``bn3`` batch-norm modules are constructed but are not
    called in ``forward``.
    """

    def __init__(self, num_channels, modes1=8, modes2=8, modes3=8, width=20, initial_step=10):
        super().__init__()

        self.modes1, self.modes2, self.modes3 = modes1, modes2, modes3
        self.width = width
        self.padding = 6  # pad the domain if input is non-periodic

        # Lifting layer: stacked input steps plus three grid coordinates.
        self.fc0 = nn.Linear(initial_step * num_channels + 3, width)

        # Modules are registered group by group (conv0-3, w0-3, bn0-3) so the
        # parameter order matches the explicit attribute-by-attribute layout.
        for idx in range(4):
            setattr(self, f"conv{idx}", SpectralConv3d(width, width, modes1, modes2, modes3))
        for idx in range(4):
            setattr(self, f"w{idx}", nn.Conv3d(width, width, 1))
        for idx in range(4):
            setattr(self, f"bn{idx}", torch.nn.BatchNorm3d(width))

        # Projection head.
        self.fc1 = nn.Linear(width, 128)
        self.fc2 = nn.Linear(128, num_channels)

    def forward(self, x, grid):
        """Run the operator; returns a channels-last tensor with an extra
        singleton axis inserted before the channel axis."""
        # x dim = [b, x1, x2, x3, t*v]
        lifted = self.fc0(torch.cat((x, grid), dim=-1))
        # Channels-first layout for the convolutional stages.
        h = lifted.permute(0, 4, 1, 2, 3)

        # Pad the domain if input is non-periodic.
        h = F.pad(h, [0, self.padding])

        stages = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        final = len(stages) - 1
        for idx, (spectral, pointwise) in enumerate(stages):
            h = spectral(h) + pointwise(h)
            # No activation after the last stage.
            if idx < final:
                h = F.gelu(h)

        # Remove the padding and go back to channels-last.
        h = h[..., :-self.padding]
        h = h.permute(0, 2, 3, 4, 1)
        out = self.fc2(F.gelu(self.fc1(h)))
        return out.unsqueeze(-2)
