import torch
import torch.nn as nn


class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2, groups=1):
        super().__init__()
        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.scale = 1 / (in_channels * out_channels)
        self.groups = groups
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, env, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, env, out_channel, x,y)
        return torch.einsum("beixy,eioxy->beoxy", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x)
        x_ft = x_ft.reshape(batchsize, self.groups, self.in_channels, x.size(-2), x.size(-1) // 2 + 1)
        # Multiply relevant Fourier modes
        weights1 = self.weights1.reshape(self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)
        weights2 = self.weights2.reshape(self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)
        out_ft = torch.zeros(batchsize, self.groups, self.out_channels, x.size(-2), x.size(-1) // 2 + 1,
                             dtype=torch.cfloat,
                             device=x.device)
        out_ft[:, :, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :, :self.modes1, :self.modes2], torch.view_as_complex(weights1))
        out_ft[:, :, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :, -self.modes1:, :self.modes2], torch.view_as_complex(weights2))
        # Return to physical space
        out_ft = out_ft.reshape(batchsize, self.groups * self.out_channels, x.size(-2), x.size(-1) // 2 + 1)
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x


class SpectralConv2dManualWeight(nn.Module):
    """2D Fourier layer with externally supplied, per-sample weight offsets.

    This layer works like a grouped spectral convolution (FFT, per-mode linear
    transform, inverse FFT). The difference is that ``forward`` receives
    weight offsets. They are added to the learned ``weights1`` / ``weights2``,
    and the result is reshaped to carry a leading batch dimension, so every
    sample is mixed with its own weights.

    NOTE (original author): batching should be implemented via vmap.

    Args:
        in_channels: Input channels per group.
        out_channels: Output channels per group.
        modes1: Number of Fourier modes kept along dim -2 (from each end).
        modes2: Number of Fourier modes kept along dim -1; at most
            ``floor(W / 2) + 1``.
        groups: Number of independent channel groups.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.scale = 1 / (in_channels * out_channels)
        self.groups = groups
        # The trailing dim of size 2 holds (real, imag); forward views it as complex.
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        """Mix channels per Fourier mode with per-sample weights.

        Shapes: (batch, groups, in, x, y) x (batch, groups, in, out, x, y)
        -> (batch, groups, out, x, y).
        """
        return torch.einsum("beixy,beioxy->beoxy", input, weights)

    def forward(self, x, weights1, weights2):
        """Apply the spectral convolution with per-sample weight offsets.

        Args:
            x: Real tensor, expected shape (batch, groups * in_channels, H, W).
            weights1: Offset added to ``self.weights1``. The sum must hold
                ``batch`` copies of the weights, e.g. shape
                (batch, groups * in_channels, out_channels, modes1, modes2, 2).
            weights2: Same as ``weights1``, but for ``self.weights2``.

        Returns:
            Real tensor of shape (batch, groups * out_channels, H, W). Its
            precision follows the input and weights (float64 in -> float64 out).
        """
        bs = x.shape[0]
        height, width = x.size(-2), x.size(-1)
        n_freq = width // 2 + 1  # rfft2 keeps only the non-negative half of dim -1
        # Compute Fourier coefficients (up to a constant normalisation factor).
        x_ft = torch.fft.rfft2(x)
        x_ft = x_ft.reshape(bs, self.groups, self.in_channels, height, n_freq)

        # Apply the external offsets, then expose the per-sample batch dim.
        weights1 = self.weights1 + weights1
        weights2 = self.weights2 + weights2
        weights1 = weights1.reshape(bs, self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)
        weights2 = weights2.reshape(bs, self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)

        # Multiply the relevant Fourier modes.
        low = self.compl_mul2d(x_ft[:, :, :, :self.modes1, :self.modes2], torch.view_as_complex(weights1))
        high = self.compl_mul2d(x_ft[:, :, :, -self.modes1:, :self.modes2], torch.view_as_complex(weights2))

        # Use the product's dtype rather than a hard-coded cfloat, so float64
        # models are not silently truncated to complex64 / float32.
        out_ft = torch.zeros(bs, self.groups, self.out_channels, height, n_freq, dtype=low.dtype,
                             device=x.device)
        # Write order is unchanged: `high` overwrites `low` if the two bands overlap.
        out_ft[:, :, :, :self.modes1, :self.modes2] = low
        out_ft[:, :, :, -self.modes1:, :self.modes2] = high

        # Return to physical space.
        out_ft = out_ft.reshape(bs, self.groups * self.out_channels, height, n_freq)
        return torch.fft.irfft2(out_ft, s=(height, width))


class HyperSpectralConv2d(nn.Module):
    """2D Fourier layer whose weights are modulated by a context vector.

    Weights are built as follows:
    - base weights: ``weights1`` / ``weights2``, shared by all samples;
    - per-sample offset: produced from a context ``c`` by the bias-free linear
      maps ``hyper_weight1`` / ``hyper_weight2``.
    Each sample is then transformed by FFT, per-mode linear mixing with its
    own weights, and inverse FFT.

    Args:
        in_channels: Input channels per group.
        out_channels: Output channels per group.
        modes1: Number of Fourier modes kept along dim -2 (from each end).
        modes2: Number of Fourier modes kept along dim -1; at most
            ``floor(W / 2) + 1``.
        ctx_dim: Size of the context vector ``c`` given to ``forward``.
        groups: Number of independent channel groups.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, ctx_dim, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.scale = 1 / (in_channels * out_channels)
        self.groups = groups
        # The trailing dim of size 2 holds (real, imag); forward views it as complex.
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(groups * in_channels, out_channels, self.modes1, self.modes2, 2))

        # Map the context to a full-size additive offset for each weight tensor.
        self.hyper_weight1 = nn.Linear(ctx_dim, self.weights1.numel(), bias=False)
        self.hyper_weight2 = nn.Linear(ctx_dim, self.weights2.numel(), bias=False)

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        """Mix channels per Fourier mode with per-sample weights.

        Shapes: (batch, groups, in, x, y) x (batch, groups, in, out, x, y)
        -> (batch, groups, out, x, y).
        """
        return torch.einsum("beixy,beioxy->beoxy", input, weights)

    def forward(self, x, c):
        """Apply the context-modulated spectral convolution.

        Args:
            x: Real tensor, expected shape (batch, groups * in_channels, H, W).
            c: Context tensor of shape (batch, ctx_dim).

        Returns:
            Real tensor of shape (batch, groups * out_channels, H, W). Its
            precision follows the inputs and weights (float64 in -> float64 out).
        """
        batchsize = x.shape[0]
        height, width = x.size(-2), x.size(-1)
        n_freq = width // 2 + 1  # rfft2 keeps only the non-negative half of dim -1
        # Compute Fourier coefficients (up to a constant normalisation factor).
        x_ft = torch.fft.rfft2(x)
        x_ft = x_ft.reshape(batchsize, self.groups, self.in_channels, height, n_freq)

        # Base weights (groups, ...) broadcast against per-sample offsets (batch, groups, ...).
        weights1 = self.weights1.reshape(self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)
        hyper_weights1 = self.hyper_weight1(c).reshape(batchsize, self.groups, self.in_channels, self.out_channels,
                                                       self.modes1, self.modes2, 2)
        weights1 = weights1 + hyper_weights1

        weights2 = self.weights2.reshape(self.groups, self.in_channels, self.out_channels, self.modes1, self.modes2, 2)
        hyper_weights2 = self.hyper_weight2(c).reshape(batchsize, self.groups, self.in_channels, self.out_channels,
                                                       self.modes1, self.modes2, 2)
        weights2 = weights2 + hyper_weights2

        # Multiply the relevant Fourier modes.
        low = self.compl_mul2d(x_ft[..., :self.modes1, :self.modes2], torch.view_as_complex(weights1))
        high = self.compl_mul2d(x_ft[..., -self.modes1:, :self.modes2], torch.view_as_complex(weights2))

        # Use the product's dtype rather than a hard-coded cfloat, so float64
        # models are not silently truncated to complex64 / float32.
        out_ft = torch.zeros(batchsize, self.groups, self.out_channels, height, n_freq,
                             dtype=low.dtype,
                             device=x.device)
        # Write order is unchanged: `high` overwrites `low` if the two bands overlap.
        out_ft[..., :self.modes1, :self.modes2] = low
        out_ft[..., -self.modes1:, :self.modes2] = high

        # Return to physical space.
        out_ft = out_ft.reshape(batchsize, self.groups * self.out_channels, height, n_freq)
        return torch.fft.irfft2(out_ft, s=(height, width))
