import torch
import torch.nn as nn
from ..utils import check_consistency
from .layers.fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D


class FNO(torch.nn.Module):
    """
    The PINA implementation of Fourier Neural Operator network.

    Fourier Neural Operator (FNO) is a general architecture for learning
    Operators. Unlike traditional machine learning methods, FNO is designed
    to map entire functions to other functions. It can be trained with
    Supervised learning strategies. FNO performs a global convolution by
    carrying out the operation in Fourier space.

    .. seealso::

        **Original reference**: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B.,
        Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). "Fourier neural operator for
        parametric partial differential equations".
        DOI: `arXiv preprint arXiv:2010.08895.
        <https://arxiv.org/abs/2010.08895>`_
    """
    def __init__(self,
                 lifting_net,
                 projecting_net,
                 n_modes,
                 dimensions = 3,
                 padding = 8,
                 padding_type = "constant",
                 inner_size = 20,
                 n_layers = 2,
                 func = nn.Tanh,
                 layers = None):
        """
        :param torch.nn.Module lifting_net: The network applied to the input
            before the Fourier blocks. It must own at least one parameter,
            since the width of its output is inferred by a probe forward pass.
        :param torch.nn.Module projecting_net: The network applied to the
            output of the Fourier blocks.
        :param n_modes: Number of modes handed to each Fourier block. Either a
            single specification (an ``int`` or a list/tuple of ``int``) shared
            by every block, or a list/tuple holding one list/tuple per block.
        :type n_modes: int | list | tuple
        :param int dimensions: Number of spatial dimensions (1, 2 or 3).
        :param padding: Amount of padding appended at the end of each spatial
            dimension before the Fourier blocks and removed afterwards. An
            ``int`` is used for every dimension; a list/tuple gives one value
            per spatial dimension, in the same order as the input layout.
        :type padding: int | list[int] | tuple[int]
        :param str padding_type: Padding mode forwarded to
            :func:`torch.nn.functional.pad`.
        :param int inner_size: Number of output fields of each Fourier block
            when ``layers`` is ``None``.
        :param int n_layers: Number of Fourier blocks when ``layers`` is
            ``None``.
        :param func: Activation class for the Fourier blocks, or a list/tuple
            with one activation class per block.
        :type func: type[torch.nn.Module] | list | tuple
        :param layers: Number of output fields of each Fourier block. When
            given it overrides ``inner_size`` and ``n_layers``.
        :type layers: list[int] | tuple[int] | None
        :raises ValueError: If an argument has an invalid type, if
            ``lifting_net`` has no parameters, or if ``padding`` has fewer
            entries than ``dimensions``.
        :raises NotImplementedError: If ``dimensions`` is not 1, 2 or 3.
        :raises RuntimeError: If the number of activation functions or of
            per-block modes differs from the number of blocks.
        """
        super().__init__()

        # check type consistency
        check_consistency(lifting_net, nn.Module)
        check_consistency(projecting_net, nn.Module)
        check_consistency(dimensions, int)
        check_consistency(padding, int)
        check_consistency(padding_type, str)
        check_consistency(inner_size, int)
        check_consistency(n_layers, int)
        check_consistency(func, nn.Module, subclass=True)
        if layers is not None:
            if isinstance(layers, (tuple, list)):
                check_consistency(layers, int)
            else:
                raise ValueError('layers must be tuple or list of int.')
        if not isinstance(n_modes, (list, tuple, int)):
            raise ValueError('n_modes must be a int or list or tuple of valid modes.'
                             ' More information on the official documentation.')

        # assign variables
        # TODO check input lifting net and input projecting net
        self._lifting_net = lifting_net
        self._projecting_net = projecting_net
        self._padding = padding

        # initialize fourier layer for each dimension
        if dimensions == 1:
            fourier_layer = FourierBlock1D
        elif dimensions == 2:
            fourier_layer = FourierBlock2D
        elif dimensions == 3:
            fourier_layer = FourierBlock3D
        else:
            raise NotImplementedError('FNO implemented only for 1D/2D/3D data.')

        # Here we build the FNO by stacking Fourier Blocks

        # 1. Assign output dimensions for each FNO layer
        if layers is None:
            layers = [inner_size] * n_layers
        else:
            # tuples are accepted by the validation above; work on a list copy
            layers = list(layers)

        # 2. Assign activation functions for each FNO layer
        if isinstance(func, (list, tuple)):
            if len(layers) != len(func):
                raise RuntimeError('Uncosistent number of layers and functions.')
            self._functions = list(func)
        else:
            self._functions = [func for _ in range(len(layers))]

        # 3. Assign modes for each FNO layer
        per_layer_modes = (
            isinstance(n_modes, (list, tuple))
            and all(isinstance(modes, (list, tuple)) for modes in n_modes)
        )
        if per_layer_modes:
            if len(layers) != len(n_modes):
                raise RuntimeError('Inconsistent number of layers and modes.')
            n_modes = list(n_modes)
        else:
            # a single specification is shared by every layer
            n_modes = [n_modes] * len(layers)

        # 4. Build the FNO network
        # The width produced by the lifting net is not known a priori, so it
        # is inferred by feeding a dummy tensor shaped like the first
        # parameter and reading the size of the last output axis.
        first_parameter = next(lifting_net.parameters(), None)
        if first_parameter is None:
            raise ValueError('lifting_net must have at least one parameter '
                             'to infer its output size.')
        with torch.no_grad():
            probe = torch.rand(size=first_parameter.size(),
                               dtype=first_parameter.dtype,
                               device=first_parameter.device)
            out_feats = lifting_net(probe).shape[-1]
        tmp_layers = [out_feats, *layers]

        self._layers = nn.Sequential(*[
            fourier_layer(input_numb_fields=tmp_layers[i],
                          output_numb_fields=tmp_layers[i + 1],
                          n_modes=n_modes[i],
                          activation=self._functions[i])
            for i in range(len(tmp_layers) - 1)
        ])

        # 5. Padding values for spectral conv
        if isinstance(padding, int):
            padding = [padding] * dimensions
        elif len(padding) < dimensions:
            raise ValueError('padding must provide one value for each of the '
                             f'{dimensions} spatial dimensions.')
        padding = list(padding[:dimensions])
        # slice ends used in forward to strip the padding, one per spatial axis
        self._ipad = [-pad if pad > 0 else None for pad in padding]
        self._padding_type = padding_type
        # torch.nn.functional.pad expects (before, after) pairs starting from
        # the LAST dimension, hence the reversed order w.r.t. ``_ipad``
        self._pad = [val for pad in reversed(padding) for val in (0, pad)]

    def forward(self, x):
        """
        Forward computation for Fourier Neural Operator. It performs a
        lifting of the input by the ``lifting_net``. Then different layers
        of Fourier Blocks are applied. Finally the output is projected
        to the final dimensionality by the ``projecting_net``.

        :param torch.Tensor x: The input tensor for fourier block, depending on
            ``dimensions`` in the initialization. In particular it is expected:

            * 1D tensors: ``[batch, X, channels]``
            * 2D tensors: ``[batch, X, Y, channels]``
            * 3D tensors: ``[batch, X, Y, Z, channels]``
        :return: The output tensor obtained from the FNO.
        :rtype: torch.Tensor
        """

        # lifting the input in higher dimensional space
        x = self._lifting_net(x)

        # permuting the input [batch, channels, x, y, ...]
        x = x.permute([0, x.ndim - 1, *range(1, x.ndim - 1)])

        # padding the input
        x = torch.nn.functional.pad(x, pad=self._pad, mode=self._padding_type)

        # apply fourier layers
        x = self._layers(x)

        # remove padding; the index must be a tuple, since indexing a tensor
        # with a list of slices is deprecated
        idxs = (slice(None), slice(None), *(slice(pad) for pad in self._ipad))
        x = x[idxs]

        # permuting back [batch, x, y, ..., channels]
        x = x.permute([0, *range(2, x.ndim), 1])

        # apply projecting operator and return
        return self._projecting_net(x)
