import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.fno import FNO2d, SpectralConv2d


"""
A 2-layer parameter embedding module for 2D data.
"""
class PrmEmb_Block_2d(nn.Module):
    """Parameter-conditioned modulation block for 2D fields.

    A stem convolution (``num_channels -> widening_factor // 2``) feeds three
    branches, each scaled channel-wise by its own MLP embedding of the
    parameters ``x_p``:

    * a spectral branch (``SpectralConv2d``) run on a 4x average-pooled copy
      and upsampled back to the stem's spatial size,
    * a 1x1 convolution branch (``CNN_AC``),
    * a grouped ``kernel_size`` convolution branch (``CNN_AD``),

    plus an unconditioned 1x1 path (``CNN_1``). The sum goes through GELU and
    a 1x1 projection back to ``num_channels``; the result ``y`` modulates the
    input as ``x * (1 + LayerNorm(y))``.

    Args:
        widening_factor: output width of the branches; the stem uses
            ``widening_factor // 2`` channels.
        num_params: number of parameters per sample in ``x_p``.
        num_channels: channel count of the input field (and of the output).
        modes1, modes2: Fourier modes forwarded to ``SpectralConv2d``.
        kernel_size: kernel size of the stem and grouped convolutions.
        normed_dim: ``normalized_shape`` of the final ``LayerNorm``. It is
            applied to the trailing dims of ``y``, so it must match the
            input's spatial size ``(H, W)``.
    """

    def __init__(self, widening_factor, num_params, num_channels, modes1, modes2,
                 kernel_size=5, normed_dim=(64, 64)):
        super().__init__()
        self.num_params = num_params
        half = widening_factor // 2

        # NOTE: submodules are created in the original order so that parameter
        # initialisation is reproducible under a fixed RNG seed.
        self.CNN_0 = nn.Conv2d(num_channels, half,
                               kernel_size=kernel_size, bias=False, padding='same')
        self.avpool = nn.AvgPool2d(4, 4)
        # NOTE(review): spatial_size is hard-coded to 64, independent of
        # normed_dim and of the pooled resolution (input size / 4). Confirm
        # what SpectralConv2d expects here.
        self.CNN_SC = SpectralConv2d(in_channels=half, out_channels=widening_factor,
                                     modes1=modes1, modes2=modes2, spatial_size=64)
        self.CNN_AD = nn.Conv2d(half, widening_factor,
                                kernel_size=kernel_size, bias=False, padding='same',
                                groups=half)
        self.CNN_1 = nn.Conv2d(half, widening_factor, kernel_size=1)
        self.CNN_AC = nn.Conv2d(half, widening_factor, kernel_size=1)
        self.CNN_2 = nn.Conv2d(widening_factor, num_channels, kernel_size=1)
        self.widening_factor = widening_factor

        # One two-layer MLP per conditioned branch: num_params -> half -> widening_factor.
        self.fc0a = nn.Linear(self.num_params, half)
        self.fc0b = nn.Linear(half, self.widening_factor)
        self.fc1a = nn.Linear(self.num_params, half)
        self.fc1b = nn.Linear(half, self.widening_factor)
        self.fc2a = nn.Linear(self.num_params, half)
        self.fc2b = nn.Linear(half, self.widening_factor)

        self.LN = nn.LayerNorm(normed_dim)

    def forward(self, inputs):
        """Modulate ``x`` by a parameter-conditioned embedding.

        Args:
            inputs: pair ``(x, x_p)``.
                ``x`` is a ``(B, num_channels, H, W)`` field.
                ``x_p`` is ``(B, num_params)``. Because ``x_p`` passes through
                ``log``, it should be positive; negative entries yield NaN.

        Returns:
            Tensor with the same shape as ``x``: ``x * (1 + LayerNorm(y))``.
        """
        x, x_p = inputs
        # sigmoid(log(p)) == p / (1 + p): squashes positive parameters into
        # (0, 1). Computed once and shared by all three embedding MLPs.
        p = torch.sigmoid(torch.log(x_p))
        # Parameter embeddings, each (B, widening_factor).
        inp0 = self.fc0b(F.gelu(self.fc0a(p)))
        inp1 = self.fc1b(F.gelu(self.fc1a(p)))
        inp2 = self.fc2b(F.gelu(self.fc2a(p)))

        y = self.CNN_0(x)
        y = F.gelu(y)
        # Spectral branch on a 4x-downsampled copy. Upsample to y's exact
        # spatial size rather than by a fixed factor of 4. This is identical
        # when H and W are multiples of 4, and keeps the branches aligned when
        # they are not; otherwise the pooling floor would make y0 too small.
        y0 = self.avpool(y)
        y0 = self.CNN_SC(y0) * inp0[:, :, None, None]
        y0 = F.interpolate(y0, size=y.shape[-2:])
        y1 = self.CNN_AC(y) * inp1[:, :, None, None]
        y2 = self.CNN_AD(y) * inp2[:, :, None, None]
        y = (y0 + y1 + y2) + self.CNN_1(y)
        y = F.gelu(y)
        y = self.CNN_2(y)

        return x * (1. + self.LN(y))

class CAPE2d(nn.Module):
    """FNO2d model conditioned on parameters through ``PrmEmb_Block_2d``.

    The input field is modulated by the parameter embedding. The modulated
    copy is concatenated to the original along the channel axis and passed
    through an ``FNO2d`` that emits ``2 * out_channels`` channels. A 1x1
    convolution then mixes those down to ``out_channels``.

    Args:
        widening_factor: hidden width of the embedding block.
        num_params: number of parameters per sample in ``xp``.
        in_channels: channel count of the input field.
        out_channels: channel count of the output field.
        width: hidden width of the ``FNO2d``.
        modes1, modes2: Fourier modes for both the embedding block and ``FNO2d``.
        normed_dim: ``LayerNorm`` shape of the embedding block; must match the
            input's spatial size ``(H, W)``.
    """

    def __init__(self, widening_factor, num_params, in_channels, out_channels, width,
                 modes1, modes2, normed_dim=(64, 64)):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.num_params = num_params

        self.PrmEmb_Pre = PrmEmb_Block_2d(widening_factor, num_params, num_channels=in_channels,
                                          modes1=modes1, modes2=modes2, kernel_size=5, normed_dim=normed_dim)
        # NOTE(review): spatial_size=64 and n_layers=4 are hard-coded rather
        # than derived from normed_dim / constructor args. Confirm this is intended.
        self.model = FNO2d(in_channels=2*in_channels, out_channels=2*out_channels,
                           modes1=modes1, modes2=modes2, width=width, spatial_size=64, n_layers=4)
        self.mix_emb = nn.Conv2d(2*out_channels, out_channels, kernel_size=1)

    def forward(self, x, xp):  # x: (B, C_comb, Nx, Ny), xp: (B, num_params)
        """Predict the output field from input field ``x`` and parameters ``xp``.

        Args:
            x: input field, ``(B, in_channels, Nx, Ny)``.
            xp: parameters, ``(B, num_params)``; passed through ``log`` inside
                the embedding block, so they should be positive.

        Returns:
            Output field, ``(B, out_channels, Nx, Ny)``.
        """
        # Parameter-modulated copy of the input, appended as extra channels.
        y = self.PrmEmb_Pre([x, xp])  # (B, C_comb, Nx, Ny)
        x = torch.cat((x, y), dim=1)  # (B, C_comb*2, Nx, Ny)

        # Forward to FNO2d.
        x = self.model(x)  # (B, C*2, Nx, Ny)

        # Mix the doubled output channels down to out_channels.
        x = self.mix_emb(x)  # (B, C, Nx, Ny)

        return x