import torch
import torch.nn as nn

from . import register_component, get_activation
from .utils import CONV_TYPES

@register_component("FourierLayer")
class FourierLayer(nn.Module):
    """
    Fourier layer: a spectral convolution and a pointwise (kernel_size=1)
    convolution applied to the same input, summed, then activated.

    Args:
        dimension: Dimension for convolution operations (1, 2, or 3); selects
            the spectral and the pointwise convolution classes.
        in_channels: Number of input channels
        out_channels: Number of output channels
        modes: Number of Fourier modes passed to the spectral convolution
        bias: Whether the pointwise convolution includes a bias
        activation: Activation name, resolved through ``get_activation``
        **kwargs: Extra keyword arguments forwarded to ``get_activation``
    """
    # TODO: For dimension==2, allow modes to be a tuple (modes_x, modes_y).
    # Same for dimension==3.

    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            modes: int = 32,
            bias: bool = True,
            activation: str = "gelu",
            **kwargs
        ):
        super().__init__()

        assert all(
            dimension in registry for registry in (SPECTRAL_CONV_TYPES, CONV_TYPES)
        ), "Dimension must be 1, 2, or 3"

        spectral_cls = SPECTRAL_CONV_TYPES[dimension]
        pointwise_cls = CONV_TYPES[dimension]

        # Keep this construction order (spectral, pointwise, activation): it
        # determines submodule/parameter registration order and the order of
        # any random draws made while initialising weights.
        self.spectral_conv = spectral_cls(in_channels, out_channels, modes)
        self.pointwise_conv = pointwise_cls(
            in_channels, out_channels, kernel_size=1, bias=bias
        )
        self.activation = get_activation(activation, **kwargs)

    def forward(self, x):
        # Both branches read the same input; their sum feeds the nonlinearity.
        spectral_out = self.spectral_conv(x)
        pointwise_out = self.pointwise_conv(x)
        return self.activation(spectral_out + pointwise_out)

## TODO: Can we use the same class for all dimensions?
class SpectralConv1d(nn.Module):
    """
    Spectral convolution in 1D.

    Takes a real FFT along the last axis and multiplies the lowest ``modes``
    frequency bins by a learned complex weight that mixes channels. All
    higher bins are zeroed, and the result is transformed back to the
    original length.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        modes: Number of Fourier modes kept. If the input is too short for its
            rFFT to provide that many bins, only the available bins are used,
            each paired with its usual (leading) weight.
    """
    def __init__(
            self, 
            in_channels: int, 
            out_channels: int, 
            modes: int
        ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Real/imaginary parts live in a trailing axis of size 2 so the
        # parameter is a plain real tensor; forward() views it as complex.
        scale = 1 / (in_channels * out_channels)
        self.weights = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                2, 
                dtype=torch.float32,
            )
        )
        
    def _batchmul1d(self, input, weights):
        # (batch, in, x) x (in, out, x) -> (batch, out, x): per-mode channel mixing.
        return torch.einsum("bix, iox -> box", input, weights)
    
    def forward(self, x):
        length = x.size(-1)
        x_fft = torch.fft.rfft(x)

        # An rFFT of `length` samples has length // 2 + 1 bins; never slice
        # past them (the weights are sliced to match).
        kept = min(self.modes, x_fft.size(-1))

        # Allocate in the spectrum's own dtype/device so double-precision
        # inputs/weights are not silently truncated to complex64.
        out_fft = x_fft.new_zeros((x.shape[0], self.out_channels, *x_fft.shape[2:]))
        out_fft[:, :, :kept] = self._batchmul1d(
            x_fft[:, :, :kept],
            torch.view_as_complex(self.weights)[:, :, :kept],
        )

        return torch.fft.irfft(out_fft, n=length)

class SpectralConv2d(nn.Module):
    """
    Spectral convolution in 2D.

    Takes a 2D real FFT over the last two axes. On the rFFT (last) axis the
    first ``modes`` bins are kept. On the full-FFT (second-to-last) axis, the
    first ``modes`` rows are multiplied by ``weights1`` and the last ``modes``
    rows by ``weights2``. Everything else is zeroed before the inverse
    transform.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        modes: Number of Fourier modes per axis. On grids too coarse to provide
            that many bins, only the available bins are used, and each weight
            stays paired with the same frequency position.
    """
    def __init__(
            self, 
            in_channels: int, 
            out_channels: int, 
            modes: int
        ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Two corner weights: leading rows (weights1) and trailing rows
        # (weights2) of the full-FFT axis. Trailing axis of size 2 holds
        # real/imaginary parts.
        scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes,
                2,
                dtype=torch.float32
            )
        )

        self.weights2 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes, 
                2, 
                dtype=torch.float32
            )
        )

    def _batchmul2d(self, input, weights):
        # (batch, in, x, y) x (in, out, x, y) -> (batch, out, x, y).
        return torch.einsum("bixy, ioxy -> boxy", input, weights)
    
    def forward(self, x):
        height, width = x.size(-2), x.size(-1)
        x_fft = torch.fft.rfft2(x)

        # Clamp to the bins that exist: `height` on the full-FFT axis,
        # width // 2 + 1 on the rFFT axis.
        rows_kept = min(self.modes, x_fft.size(-2))
        cols_kept = min(self.modes, x_fft.size(-1))

        # Match the spectrum's dtype/device (no silent complex64 downcast).
        out_fft = x_fft.new_zeros((x.shape[0], self.out_channels, *x_fft.shape[2:]))

        # The same slices index the spectrum and the weights. For the trailing
        # corner, taking the tail of weights2 keeps each weight paired with the
        # same frequency as when all `modes` rows are present. If the corners
        # overlap, the later (weights2) write wins, as before.
        cols = slice(None, cols_kept)
        corners = (
            (self.weights1, slice(None, rows_kept)),
            (self.weights2, slice(-rows_kept, None)),
        )
        for weights, rows in corners:
            out_fft[:, :, rows, cols] = self._batchmul2d(
                x_fft[:, :, rows, cols],
                torch.view_as_complex(weights)[:, :, rows, cols],
            )

        return torch.fft.irfft2(out_fft, s=(height, width))

class SpectralConv3d(nn.Module):
    """
    Spectral convolution in 3D.

    Takes a 3D real FFT over the last three axes. The first ``modes`` bins of
    the rFFT (last) axis are kept. On the two full-FFT axes, the four
    leading/trailing ``modes``-sized corners are multiplied by ``weights1``
    to ``weights4`` respectively. Everything else is zeroed before the
    inverse transform.

    Note: ``torch.fft`` has no ``rfft3``/``irfft3``; the N-D transforms
    restricted to the last three axes are used instead.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        modes: Number of Fourier modes per axis. On grids too coarse to provide
            that many bins, only the available bins are used, and each weight
            stays paired with the same frequency position.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            modes: int
        ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Four corner weights (leading/trailing on each of the two full-FFT
        # axes). Trailing axis of size 2 holds real/imaginary parts.
        scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes, 
                modes, 
                2, 
                dtype=torch.float32
            )
        )

        self.weights2 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes, 
                modes, 
                2, 
                dtype=torch.float32
            )
        )

        self.weights3 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes, 
                modes, 
                2, 
                dtype=torch.float32
            )
        )

        self.weights4 = nn.Parameter(
            scale * torch.rand(
                in_channels, 
                out_channels, 
                modes, 
                modes,
                modes,
                2,
                dtype=torch.float32
            )
        )

    def _batchmul3d(self, input, weights):
        # (batch, in, x, y, z) x (in, out, x, y, z) -> (batch, out, x, y, z).
        return torch.einsum("bixyz, ioxyz -> boxyz", input, weights)
    
    def forward(self, x):
        depth, height, width = x.size(-3), x.size(-2), x.size(-1)
        spatial_dims = (-3, -2, -1)

        # `dim` must be explicit: rfftn would otherwise transform every axis,
        # including batch and channel.
        x_fft = torch.fft.rfftn(x, dim=spatial_dims)

        # Clamp to the bins that exist on each axis.
        depth_kept = min(self.modes, x_fft.size(-3))
        height_kept = min(self.modes, x_fft.size(-2))
        width_kept = min(self.modes, x_fft.size(-1))

        # Match the spectrum's dtype/device (no silent complex64 downcast).
        out_fft = x_fft.new_zeros((x.shape[0], self.out_channels, *x_fft.shape[2:]))

        lead_d, tail_d = slice(None, depth_kept), slice(-depth_kept, None)
        lead_h, tail_h = slice(None, height_kept), slice(-height_kept, None)
        lead_w = slice(None, width_kept)

        # Same write order as the original four assignments: it decides which
        # corner wins where corners overlap. The same slices index the weights,
        # so tail slices keep each weight on the same frequency position.
        corners = (
            (self.weights1, lead_d, lead_h),
            (self.weights2, tail_d, lead_h),
            (self.weights3, lead_d, tail_h),
            (self.weights4, tail_d, tail_h),
        )
        for weights, d_sl, h_sl in corners:
            out_fft[:, :, d_sl, h_sl, lead_w] = self._batchmul3d(
                x_fft[:, :, d_sl, h_sl, lead_w],
                torch.view_as_complex(weights)[:, :, d_sl, h_sl, lead_w],
            )

        return torch.fft.irfftn(out_fft, s=(depth, height, width), dim=spatial_dims)

# Spatial dimension -> spectral convolution class, looked up by FourierLayer.
SPECTRAL_CONV_TYPES = dict(
    enumerate((SpectralConv1d, SpectralConv2d, SpectralConv3d), start=1)
)