"""Module for Continuous Convolution class"""
from .convolution import BaseContinuousConv
from .utils_convolution import check_point, map_points_
from .integral import Integral
import torch


class ContinuousConvBlock(BaseContinuousConv):
    """
    Implementation of Continuous Convolutional operator.

    .. seealso::

        **Original reference**: Coscia, D., Meneghetti, L., Demo, N.,
        Stabile, G., & Rozza, G.. (2022). A Continuous Convolutional Trainable
        Filter for Modelling Unstructured Data.
        DOI: `10.48550/arXiv.2210.13416
        <https://doi.org/10.48550/arXiv.2210.13416>`_.

    """

    def __init__(self, input_numb_field, output_numb_field,
                 filter_dim, stride, model=None, optimize=False,
                 no_overlap=False):
        """

        :param input_numb_field: Number of fields N_in in the input.
        :type input_numb_field: int
        :param output_numb_field: Number of fields N_out in the output.
        :type output_numb_field: int
        :param filter_dim: Dimension of the filter.
        :type filter_dim: tuple/ list
        :param stride: Stride for the filter.
        :type stride: dict
        :param model: Neural network for inner parametrization,
            defaults to None. If None, a default multilayer perceptron
            is used, see BaseContinuousConv.DefaultKernel. Otherwise it must
            be a class (not an instance): it is called with no arguments
            once for every (input field, output field) pair, and each call
            must return a ``torch.nn.Module``.
        :type model: torch.nn.Module, optional
        :param optimize: Flag for performing optimization on the continuous
            filter, defaults to False. The flag `optimize=True` should be
            used only when the scatter datapoints are fixed through the
            training. If torch model is in `.eval()` mode, the flag is
            automatically set to False always.
        :type optimize: bool, optional
        :param no_overlap: Flag for performing optimization on the transpose
            continuous filter, defaults to False. The flag set to `True` should
            be used only when the filter positions do not overlap for different
            strides. With `no_overlap=True` the transpose evaluates every
            kernel once on the points of all filter positions together,
            instead of once per filter position.
        :type no_overlap: bool, optional

        .. note::
            Using `optimize=True` the filter can be used either in `forward`
            or in `transpose` mode, not both. If `optimize=False` the same
            filter can be used for both `transpose` and `forward` modes.

        .. warning::
            The algorithm expects input to be in the form: [B x N_in x N x D]
            where B is the batch_size, N_in is the number of input
            fields, N the number of points in the mesh, D the dimension
            of the problem. In particular:

            *   D is the number of spatial variables + 1. The last column must
                contain the field value. For example for 2D problems D=3 and
                the tensor will be something like `[first coordinate, second
                coordinate, field value]`.

            *   N_in represents the number of vectorial function presented.
                For example a vectorial function f = [f_1, f_2] will have
                N_in=2.

            The algorithm returns a tensor of shape: [B x N_out x N' x D]
            where B is the batch_size, N_out is the number of output
            fields, N' the number of points in the output mesh (one point
            per filter position), D the dimension of the problem
            (coordinates + field value).

            For example, a 2-dimensional vectorial function N_in=2 of
            3-dimensional input D=3+1=4 with 100 points input mesh and batch
            size of 8 is represented as a tensor `[8, 2, 100, 4]`, where the
            columns `[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
            second field value respectively.

        :Example:
            >>> class MLP(torch.nn.Module):
                    def __init__(self) -> None:
                        super().__init__()
                        self.model = torch.nn.Sequential(
                                                        torch.nn.Linear(2, 8),
                                                        torch.nn.ReLU(),
                                                        torch.nn.Linear(8, 8),
                                                        torch.nn.ReLU(),
                                                        torch.nn.Linear(8, 1))
                    def forward(self, x):
                        return self.model(x)
            >>> dim = [3, 3]
            >>> stride = {"domain": [10, 10],
                          "start": [0, 0],
                          "jumps": [3, 3],
                          "direction": [1, 1.]}
            >>> conv = ContinuousConvBlock(1, 2, dim, stride, MLP)
            >>> conv
                ContinuousConvBlock(
                (_net): ModuleList(
                    (0): MLP(
                    (model): Sequential(
                        (0): Linear(in_features=2, out_features=8, bias=True)
                        (1): ReLU()
                        (2): Linear(in_features=8, out_features=8, bias=True)
                        (3): ReLU()
                        (4): Linear(in_features=8, out_features=1, bias=True)
                    )
                    )
                    (1): MLP(
                    (model): Sequential(
                        (0): Linear(in_features=2, out_features=8, bias=True)
                        (1): ReLU()
                        (2): Linear(in_features=8, out_features=8, bias=True)
                        (3): ReLU()
                        (4): Linear(in_features=8, out_features=1, bias=True)
                    )
                    )
                )
                )
        """

        super().__init__(input_numb_field=input_numb_field,
                         output_numb_field=output_numb_field,
                         filter_dim=filter_dim,
                         stride=stride,
                         model=model,
                         optimize=optimize,
                         no_overlap=no_overlap)

        # integral routine
        self._integral = Integral('discrete')

        # create the network
        self._net = self._spawn_networks(model)

        # stride for continuous convolution overridden
        self._stride = self._stride._stride_discrete

    def _spawn_networks(self, model):
        """Private method to create one kernel for every pair of
        (input field, output field).

        The kernel for output field `o` and input field `i` is stored at
        position ``o * input_numb_field + i``.

        :param model: class (not instance) used to build each kernel, or
            None to use ``DefaultKernel``.
        :type model: type or None
        :raises ValueError: if `model` is not callable or does not build
            ``torch.nn.Module`` instances.
        :return: ``input_numb_field * output_numb_field`` kernels
        :rtype: torch.nn.ModuleList

        """
        numb_kernels = self._input_numb_field * self._output_numb_field

        # decide on the argument itself, not on state set by the base class
        if model is None:
            return torch.nn.ModuleList(
                [ContinuousConvBlock.DefaultKernel(len(self._dim), 1)
                 for _ in range(numb_kernels)])

        # `isinstance(model, object)` is always true: check what is used below
        if not callable(model):
            raise ValueError("Expected a python class inheriting"
                             " from torch.nn.Module")

        nets = []
        for _ in range(numb_kernels):
            tmp = model()
            if not isinstance(tmp, torch.nn.Module):
                raise ValueError("The python class must be inherited from"
                                 " torch.nn.Module. See the docstring for"
                                 " an example.")
            nets.append(tmp)

        return torch.nn.ModuleList(nets)

    def _extract_mapped_points(self, batch_idx, index, x):
        """Private method to extract the points falling inside each filter
        position and map them in the filter domain.

        :param batch_idx: position of `x` in the batch.
        :type batch_idx: int
        :param index: masks of the points inside each filter position,
            indexed as ``index[stride_idx][batch_idx]`` (see `_find_index`).
        :type index: list
        :param x: input tensor [channel x N x dim]
        :type x: torch.tensor
        :return: for every channel, its mapped points concatenated stride by
            stride; and for every channel, the number of its points falling
            in each stride.
        :rtype: tuple(tuple(torch.tensor), tuple(tuple(int)))

        """
        mapped_points = []
        indeces_channels = []

        for stride_idx, current_stride in enumerate(self._stride):

            # indeces of points falling into filter range
            indeces = index[stride_idx][batch_idx]

            # how many points for each channel fall into the filter?
            numb_points_inside = torch.sum(indeces, dim=-1).tolist()

            # extracting points for each channel
            # shape: [sum(numb_points_inside), filter_dim + 1]
            point_stride = x[indeces]

            # mapping points in filter domain (in place on the copy)
            map_points_(point_stride[..., :-1], current_stride)

            # extracting points for each channel
            point_stride_channel = point_stride.split(numb_points_inside)

            # appending in list for later use
            mapped_points.append(point_stride_channel)
            indeces_channels.append(numb_points_inside)

        # regroup per channel: concatenate the points of all strides
        stacked_input = tuple(map(torch.cat, zip(*mapped_points)))
        indeces_channels = tuple(zip(*indeces_channels))

        return stacked_input, indeces_channels

    def _find_index(self, X):
        """Private method to extract indeces for convolution.

        Stores in ``self._index`` the result of ``check_point`` for every
        filter position (one entry per stride).

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor

        """
        self._index = [check_point(X, current_stride, self._dim)
                       for current_stride in self._stride]

    def _make_grid_forward(self, X):
        """Private method to create forward convolution grid.

        The grid has shape [B x N_out x N' x (filter_dim + 1)], with N' the
        number of filter positions; the coordinates are the centers of the
        filters and the last column is zero.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor

        """
        # filter dimension + number of points in output grid
        filter_dim = len(self._dim)
        number_points = len(self._stride)

        # initialize the grid
        grid = torch.zeros(size=(X.shape[0],
                                 self._output_numb_field,
                                 number_points,
                                 filter_dim + 1),
                           device=X.device,
                           dtype=X.dtype)
        grid[..., :-1] = (self._stride + self._dim * 0.5)

        # saving the grid
        self._grid = grid.detach()

    def _make_grid_transpose(self, X):
        """Private method to create transpose convolution grid.

        The grid is a copy of the coordinates of `X` with the last column
        set to zero.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor

        """
        # initialize to all zeros
        tmp = torch.zeros_like(X)
        tmp[..., :-1] = X[..., :-1]

        # save on tmp
        self._grid_transpose = tmp

    def _make_grid(self, X, type):
        """Private method to create convolution grid.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        :param type: type of convolution, ['forward', 'inverse'] the
            possibilities
        :type type: string
        :raises TypeError: if `type` is not one of the possibilities.

        """
        # choose the type of convolution
        if type == 'forward':
            self._make_grid_forward(X)
        elif type == 'inverse':
            self._make_grid_transpose(X)
        else:
            raise TypeError(f"Unknown convolution type {type!r}, expected "
                            "'forward' or 'inverse'.")

    def _initialize_convolution(self, X, type='forward'):
        """Private method to intialize the convolution.
        The convolution is initialized by setting a grid and
        calculate the index for finding the points inside the
        filter.

        :param X: input tensor, as in ContinuousConvBlock docstring
        :type X: torch.tensor
        :param type: type of convolution, ['forward', 'inverse'] the
            possibilities
        :type type: string
        """

        # variable for the convolution
        self._make_grid(X, type)

        # calculate the index
        self._find_index(X)

    def forward(self, X):
        """Forward pass in the layer

        :param X: input data [B x N_in x N x D], as in the class docstring
        :type X: torch.tensor
        :return: feed forward convolution [B x N_out x N' x D], where N' is
            the number of filter positions; the coordinates are the filter
            centers and the last column stores the convolution.
        :rtype: torch.tensor
        """

        # initialize convolution
        if self.training:  # we choose what to do based on optimization
            self._choose_initialization(X, type='forward')

        else:  # we always initialize on testing
            self._initialize_convolution(X, 'forward')

        # create convolutional array
        conv = self._grid.clone().detach()

        # total number of fields
        tot_dim = self._output_numb_field * self._input_numb_field

        for batch_idx, x in enumerate(X):

            # extract mapped points
            stacked_input, indeces_channels = self._extract_mapped_points(
                batch_idx, self._index, x)

            # storing intermediate results for each channel convolution;
            # kernel `idx_conv` maps input field `idx_conv % N_in` to
            # output field `idx_conv // N_in`
            res_tmp = []
            for idx_conv in range(tot_dim):
                # index for each input field
                idx = idx_conv % self._input_numb_field
                # extract input for each channel
                single_channel_input = stacked_input[idx]
                # extract filter
                net = self._net[idx_conv]
                # calculate filter value
                stacked_output = net(single_channel_input[..., :-1])
                # perform integral for all strides in one field
                integral = self._integral(stacked_output,
                                          single_channel_input[..., -1],
                                          indeces_channels[idx])
                res_tmp.append(integral)

            # stacking integral results
            res_tmp = torch.stack(res_tmp)

            # sum filters (for each input fields) in groups
            # for different output fields
            conv[batch_idx, ..., -1] = res_tmp.reshape(self._output_numb_field,
                                                       self._input_numb_field,
                                                       -1).sum(1)
        return conv

    def transpose_no_overlap(self, integrals, X):
        """Transpose pass in the layer for no-overlapping filters

        :param integrals: Weights for the transpose convolution. Shape
            [B x N_in x N] where B is the batch_size, N_in is the number of
            input fields and N the number of filter positions (one value
            per stride).
        :type integrals: torch.tensor
        :param X: Input data. Expect tensor of shape
            [B x C x M x D] where B is the batch_size, C the number of
            channels of the mesh (the points of channel `o` are used for
            output field `o`), M the number of points in the mesh, D the
            dimension of the problem. Note, only the coordinates are used:
            the last column of `X` is ignored.
        :type X: torch.tensor
        :return: Feed forward transpose convolution. Tensor with the same
            shape as `X`: the coordinates are copied from `X` and the last
            column stores the transpose convolution of each output field.
        :rtype: torch.tensor

        .. note::
            This function is automatically called when `.transpose()`
            method is used and `no_overlap=True`
        """

        # initialize convolution
        if self.training:  # we choose what to do based on optimization
            self._choose_initialization(X, type='inverse')

        else:  # we always initialize on testing
            self._initialize_convolution(X, 'inverse')

        # grid with the coordinates of X and a zero field column
        X = self._grid_transpose.clone().detach()
        conv_transposed = self._grid_transpose.clone().detach()
        number_points = X.shape[2]

        for batch_idx, x in enumerate(X):

            # mapped points of every channel, concatenated stride by stride
            stacked_input, indeces_channels = self._extract_mapped_points(
                batch_idx, self._index, x)

            res_channels = []
            for idx_out in range(self._output_numb_field):
                single_channel_input = stacked_input[idx_out]
                kernel_input = single_channel_input[..., :-1]

                # number of points of this channel in each stride: used to
                # give every point the weight of the stride it belongs to
                rep_idx = torch.tensor(indeces_channels[idx_out],
                                       device=integrals.device)

                # mesh positions of the points of `single_channel_input`,
                # in the same stride-by-stride order
                positions = torch.cat([
                    torch.nonzero(index[batch_idx][idx_out],
                                  as_tuple=True)[0]
                    for index in self._index])

                # sum the contributions of all input fields, using the same
                # kernel as `forward` for the pair (idx_out, idx_in)
                partial = []
                for idx_in in range(self._input_numb_field):
                    net = self._net[idx_out * self._input_numb_field + idx_in]
                    weights = integrals[batch_idx,
                                        idx_in, :].repeat_interleave(rep_idx)
                    partial.append(net(kernel_input).flatten() * weights)
                values = torch.stack(partial).sum(0)

                # scatter back to mesh ordering: concatenated strides are
                # generally NOT in the mesh order of the points
                res_channels.append(
                    values.new_zeros(number_points).index_add(0, positions,
                                                              values))

            conv_transposed[batch_idx, ..., -1] = torch.stack(res_channels)

        return conv_transposed

    def transpose_overlap(self, integrals, X):
        """Transpose pass in the layer for overlapping filters

        :param integrals: Weights for the transpose convolution. Shape
            [B x N_in x N] where B is the batch_size, N_in is the number of
            input fields and N the number of filter positions (one value
            per stride).
        :type integrals: torch.tensor
        :param X: Input data. Expect tensor of shape
            [B x C x M x D] where B is the batch_size, C the number of
            channels of the mesh (the points of channel `o` are used for
            output field `o`), M the number of points in the mesh, D the
            dimension of the problem. Note, only the coordinates are used:
            the last column of `X` is ignored.
        :type X: torch.tensor
        :return: Feed forward transpose convolution. Tensor with the same
            shape as `X`: the coordinates are copied from `X` and the last
            column stores the transpose convolution of each output field.
        :rtype: torch.tensor

        .. note:: This function is automatically called when `.transpose()`
            method is used and `no_overlap=False`
        """

        # initialize convolution
        if self.training:  # we choose what to do based on optimization
            self._choose_initialization(X, type='inverse')

        else:  # we always initialize on testing
            self._initialize_convolution(X, 'inverse')

        # grid with the coordinates of X and a zero field column
        X = self._grid_transpose.clone().detach()
        conv_transposed = self._grid_transpose.clone().detach()

        # output field of each kernel: kernel `idx_conv` maps input field
        # `idx_conv % N_in` to output field `idx_conv // N_in`
        iterate_conv = [idx_out
                        for idx_out in range(self._output_numb_field)
                        for _ in range(self._input_numb_field)]

        for batch_idx, x in enumerate(X):

            # accumulator for the convolution on different batches
            accumulator_batch = torch.zeros(
                size=(self._grid_transpose.shape[1],
                      self._grid_transpose.shape[2]),
                requires_grad=True,
                device=X.device,
                dtype=X.dtype).clone()

            for stride_idx, current_stride in enumerate(self._stride):
                # indeces of points falling into filter range
                indeces = self._index[stride_idx][batch_idx]

                # number of points for each channel
                numb_pts_channel = indeces.sum(dim=-1).tolist()

                # extracting points for each channel
                point_stride = x[indeces]

                # if no points to upsample we just skip
                if point_stride.nelement() == 0:
                    continue

                # mapping points in filter domain
                map_points_(point_stride[..., :-1], current_stride)

                # input points for kernels
                # we split for extracting number of points for each channel
                nn_input_pts = point_stride[..., :-1].split(numb_pts_channel)

                # accumulate partial convolution results for each field
                res_tmp = []

                # for each channel field compute transpose convolution
                for idx_conv, idx_channel_out in enumerate(iterate_conv):

                    # index for input channels
                    idx_channel_in = idx_conv % self._input_numb_field

                    # extract filter
                    net = self._net[idx_conv]

                    # calculate filter value
                    stacked_output = net(nn_input_pts[idx_channel_out])

                    # weight the filter with the stride integral
                    integral = stacked_output * integrals[batch_idx,
                                                          idx_channel_in,
                                                          stride_idx]
                    # append results
                    res_tmp.append(integral.flatten())

                # sum the N_in consecutive kernels of every output field
                channel_sum = []
                start = 0
                for _ in range(self._output_numb_field):
                    tmp = res_tmp[start:start + self._input_numb_field]
                    tmp = torch.vstack(tmp).sum(dim=0)
                    channel_sum.append(tmp)
                    start += self._input_numb_field

                # accumulate the results
                accumulator_batch[indeces] += torch.hstack(channel_sum)

            # save results of accumulation for each batch
            conv_transposed[batch_idx, ..., -1] = accumulator_batch

        return conv_transposed
