
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from models.custom_layers import FeedForward, act_registry
class FSpectralConv2d(nn.Module):
    '''Factorized 2-D spectral convolution followed by a feed-forward block.

    Adapted from https://github.com/alasdairtran/fourierflow

    The Fourier part treats the two spatial axes independently. Each axis is
    transformed with a real FFT. Its lowest ``modes`` frequencies are either
    mixed across channels by a learned complex weight (``mode='full'``) or
    copied unchanged (``mode='low-pass'``). The remaining frequencies are
    zeroed, and the two inverse-transformed results are summed.
    ``mode='no-fourier'`` skips the Fourier part and only applies
    ``backcast_ff``.

    Args:
        d_model: channel width (input and output).
        modes: number of low-frequency modes kept along each axis.
        forecast_ff: accepted for interface compatibility; unused.
        backcast_ff: module applied after the Fourier part. When falsy, a
            ``FeedForward`` is built from ``factor``, ``ff_weight_norm``,
            ``n_ff_layers``, ``layer_norm`` and ``dropout``.
        fourier_weight: optional (e.g. shared) sequence of two real tensors
            of shape (d_model, d_model, modes, 2). Index 0 filters the last
            grid axis and index 1 the second-to-last. When None or empty,
            fresh Xavier-normal weights are created.
        use_fork: stored only; the forecast branch is disabled.
        mode: 'full', 'low-pass' or 'no-fourier'.
        **kwargs: ignored.
    '''

    def __init__(self, d_model, modes, forecast_ff = None, backcast_ff = None,
                 fourier_weight = None, factor = 4, ff_weight_norm = False,
                 n_ff_layers = 2, layer_norm = False, use_fork = False, dropout = 0.0, mode = 'full', **kwargs):
        super().__init__()
        self.in_dim = d_model
        self.out_dim = d_model
        self.n_modes = modes
        self.mode = mode
        self.use_fork = use_fork

        self.fourier_weight = fourier_weight
        # Complex weights are stored as real tensors with a trailing
        # (real, imag) axis and viewed as complex in forward_fourier; see
        # https://github.com/pytorch/pytorch/issues/59998
        # Explicit None/len test: ``not`` on a multi-element tensor raises.
        if fourier_weight is None or len(fourier_weight) == 0:
            self.fourier_weight = nn.ParameterList([])
            for _ in range(2):  # one weight per spatial axis
                param = nn.Parameter(
                    torch.empty(d_model, d_model, modes, 2, dtype=torch.float32))
                nn.init.xavier_normal_(param)
                self.fourier_weight.append(param)

        # The forecast branch (use_fork / forecast_ff) is currently disabled.
        self.backcast_ff = backcast_ff
        if not self.backcast_ff:
            self.backcast_ff = FeedForward(
                d_model, factor, ff_weight_norm = ff_weight_norm, n_layers = n_ff_layers, layer_norm = layer_norm, dropout = dropout)

    def forward(self, x, batch_dt = None):
        '''Apply the Fourier part (unless disabled), then ``backcast_ff``.

        Args:
            x: tensor of shape [batch_size, grid_x, grid_y, d_model].
            batch_dt: unused; accepted for interface compatibility.

        Returns:
            ``(b, None)``, where ``b`` is the output of ``backcast_ff`` and the
            second slot is the disabled forecast output.
        '''
        if self.mode != 'no-fourier':
            x = self.forward_fourier(x)
        return self.backcast_ff(x), None

    def forward_fourier(self, x):
        '''Factorized spectral convolution over both spatial axes.

        Args:
            x: tensor of shape [batch_size, M, N, d_model] (channels last).

        Returns:
            Tensor of shape [batch_size, M, N, d_model]: the sum of the signal
            filtered along N and the signal filtered along M.

        Raises:
            ValueError: if ``self.mode`` is neither 'full' nor 'low-pass'.
        '''
        if self.mode not in ('full', 'low-pass'):
            # Previously an unknown mode silently produced all zeros here.
            raise ValueError(f"Mode {self.mode} not recognized")
        full = self.mode == 'full'
        k = self.n_modes

        x = x.permute(0, 3, 1, 2)
        # x.shape == [batch_size, in_dim, M, N]
        B, I, M, N = x.shape

        # # # Dimension Y (last axis) # # #
        x_fty = torch.fft.rfft(x, dim=-1, norm='ortho')
        # x_fty.shape == [batch_size, in_dim, M, N // 2 + 1] (complex)
        out_ft = x_fty.new_zeros(B, I, M, N // 2 + 1)
        if full:
            out_ft[:, :, :, :k] = torch.einsum(
                "bixy,ioy->boxy",
                x_fty[:, :, :, :k],
                torch.view_as_complex(self.fourier_weight[0]))
        else:
            out_ft[:, :, :, :k] = x_fty[:, :, :, :k]
        xy = torch.fft.irfft(out_ft, n=N, dim=-1, norm='ortho')
        # xy.shape == [batch_size, in_dim, M, N]

        # # # Dimension X (second-to-last axis) # # #
        x_ftx = torch.fft.rfft(x, dim=-2, norm='ortho')
        # x_ftx.shape == [batch_size, in_dim, M // 2 + 1, N] (complex)
        out_ft = x_ftx.new_zeros(B, I, M // 2 + 1, N)
        if full:
            out_ft[:, :, :k, :] = torch.einsum(
                "bixy,iox->boxy",
                x_ftx[:, :, :k, :],
                torch.view_as_complex(self.fourier_weight[1]))
        else:
            out_ft[:, :, :k, :] = x_ftx[:, :, :k, :]
        xx = torch.fft.irfft(out_ft, n=M, dim=-2, norm='ortho')
        # xx.shape == [batch_size, in_dim, M, N]

        # # Combining Dimensions # #
        return (xx + xy).permute(0, 2, 3, 1)
    
class FSpectralConv1d(nn.Module):
    '''1-D spectral convolution followed by a feed-forward block and activation.

    Adapted from https://github.com/alasdairtran/fourierflow

    The input is transformed with a real FFT along the grid axis. Its lowest
    ``modes`` frequencies are either mixed across channels by a learned
    complex weight (``mode='full'``) or copied unchanged
    (``mode='low-pass'``). The remaining frequencies are zeroed, and the
    result is transformed back. ``mode='no-fourier'`` skips this step.

    Args:
        d_model: channel width (input and output).
        modes: number of low-frequency modes kept.
        forecast_ff: accepted for interface compatibility; unused.
        backcast_ff: module applied after the Fourier part. When falsy, a
            ``FeedForward`` is built from ``factor``, ``ff_weight_norm``,
            ``n_ff_layers``, ``layer_norm`` and ``dropout``.
        fourier_weight: optional (e.g. shared) sequence holding one real
            tensor of shape (d_model, d_model, modes, 2). When None or empty,
            a fresh Xavier-normal weight is created.
        use_fork: stored only; the forecast branch is disabled.
        mode: 'full', 'low-pass' or 'no-fourier'.
        activation: key looked up in ``act_registry``; the result is applied
            to the backcast output.
        fft_norm: ``norm`` argument passed to ``torch.fft.rfft``/``irfft``.
        **kwargs: ignored.
    '''

    def __init__(self, d_model, modes, forecast_ff = None, backcast_ff = None,
                 fourier_weight = None, factor = 4, ff_weight_norm = False,
                 n_ff_layers = 2, layer_norm = False, use_fork = False, dropout = 0.0, mode = 'full', activation = "identity",
                  fft_norm = "ortho", **kwargs):
        super().__init__()
        self.in_dim = d_model
        self.out_dim = d_model
        self.n_modes = modes
        self.mode = mode
        self.use_fork = use_fork
        self.fft_norm = fft_norm

        self.fourier_weight = fourier_weight
        # Complex weights are stored as real tensors with a trailing
        # (real, imag) axis and viewed as complex in forward_fourier; see
        # https://github.com/pytorch/pytorch/issues/59998
        # Explicit None/len test: ``not`` on a multi-element tensor raises.
        if fourier_weight is None or len(fourier_weight) == 0:
            self.fourier_weight = nn.ParameterList([])
            # A single weight: only the x dimension is filtered.
            param = nn.Parameter(
                torch.empty(d_model, d_model, modes, 2, dtype=torch.float32))
            nn.init.xavier_normal_(param)
            self.fourier_weight.append(param)

        # The forecast branch (use_fork / forecast_ff) is currently disabled.
        self.backcast_ff = backcast_ff
        if not self.backcast_ff:
            self.backcast_ff = FeedForward(
                d_model, factor, ff_weight_norm = ff_weight_norm, n_layers = n_ff_layers, layer_norm = layer_norm, dropout = dropout)

        self.act = act_registry[activation]

    def forward(self, x, batch_dt = None):
        '''Apply the Fourier part (unless disabled), ``backcast_ff``, then ``act``.

        Args:
            x: tensor of shape [batch_size, grid_size, d_model].
            batch_dt: unused; accepted for interface compatibility.

        Returns:
            ``(out, None)``, where ``out`` is ``act(backcast_ff(...))``. The
            ``None`` keeps the framework's (out, state) return convention.
        '''
        if self.mode != 'no-fourier':
            x = self.forward_fourier(x)
        b = self.backcast_ff(x)
        return self.act(b), None

    def forward_fourier(self, x):
        '''Spectral convolution along the grid axis.

        Args:
            x: tensor of shape [batch_size, grid_size, d_model] (channels last).

        Returns:
            Tensor of shape [batch_size, grid_size, d_model].

        Raises:
            ValueError: if ``self.mode`` is neither 'full' nor 'low-pass'.
        '''
        if self.mode not in ('full', 'low-pass'):
            raise ValueError(f"Mode {self.mode} not recognized")
        k = self.n_modes

        x = x.permute(0, 2, 1)
        # x.shape == [batch_size, d_model, grid_size]
        B, H, Sx = x.shape

        x_ft = torch.fft.rfft(x, dim=-1, norm=self.fft_norm)
        # x_ft.shape == [batch_size, d_model, grid_size // 2 + 1] (complex)
        out_ft = x_ft.new_zeros(B, H, Sx // 2 + 1)

        if self.mode == 'full':
            out_ft[:, :, :k] = torch.einsum(
                "bix,iox->box",
                x_ft[:, :, :k],
                torch.view_as_complex(self.fourier_weight[0]))
        else:  # 'low-pass': keep the lowest k modes unchanged
            out_ft[:, :, :k] = x_ft[:, :, :k]

        out = torch.fft.irfft(out_ft, n=Sx, dim=-1, norm=self.fft_norm)
        # out.shape == [batch_size, d_model, grid_size]
        return out.permute(0, 2, 1)


# class FNOFactorized2DBlock(nn.Module):
#     '''Adapted from https://github.com/alasdairtran/fourierflow'''
#     def __init__(self, d_model, modes,
#                   dropout=0.0, in_dropout=0.0,
#                  n_layers=1, share_weight: bool = False,
#                  share_fork=False, factor=2,
#                  ff_weight_norm=False, n_ff_layers=2,
#                  gain=1, layer_norm=False, use_fork=False, mode='full'):
#         super().__init__()
#         self.modes = modes
#         self.d_model = d_model
#         self.drop = nn.Dropout(in_dropout)
#         self.n_layers = n_layers
#         # self.use_fork = use_fork

#         self.forecast_ff = self.backcast_ff = None
#         # if share_fork:
#         #     if use_fork:
#         #         self.forecast_ff = FeedForward(
#         #             d_model, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)
#         #     self.backcast_ff = FeedForward(
#         #         d_model, factor, ff_weight_norm, n_ff_layers, layer_norm, dropout)

#         # self.fourier_weight = None
#         # if share_weight:
#         #     self.fourier_weight = nn.ParameterList([])
#         #     for _ in range(2):
#         #         weight = torch.FloatTensor(width, width, modes, 2)
#         #         param = nn.Parameter(weight)
#         #         nn.init.xavier_normal_(param, gain=gain)
#         #         self.fourier_weight.append(param)

#         self.spectral_layers = nn.ModuleList([])
#         for _ in range(n_layers):
#             self.spectral_layers.append(SpectralConv2d(in_dim=d_model,
#                                                        out_dim=d_model,
#                                                        n_modes=modes,
#                                                        forecast_ff=self.forecast_ff,
#                                                        backcast_ff=self.backcast_ff,
#                                                        fourier_weight=self.fourier_weight,
#                                                        factor=factor,
#                                                        ff_weight_norm=ff_weight_norm,
#                                                        n_ff_layers=n_ff_layers,
#                                                        layer_norm=layer_norm,
#                                                        use_fork=use_fork,
#                                                        dropout=dropout,
#                                                        mode=mode))

#         # self.out = nn.Sequential(
#         #     nn.Linear(self.d_model, 128),
#         #     nn.Linear(128, 1))

#     def forward(self, x, **kwargs):
#         # x.shape == [n_batches, *dim_sizes, input_size]
#         forecast = 0
#         x = self.in_proj(x)
#         x = self.drop(x)
#         forecast_list = []
#         for i in range(self.n_layers):
#             layer = self.spectral_layers[i]
#             b, f = layer(x)

#             if self.use_fork:
#                 f_out = self.out(f)
#                 forecast = forecast + f_out
#                 forecast_list.append(f_out)

#             x = x + b

#         if not self.use_fork:
#             forecast = self.out(b)

#         return {
#             'forecast': forecast,
#             'forecast_list': forecast_list,
#         }
