"""Module for Base Continuous Convolution class."""
from abc import ABCMeta, abstractmethod
import torch
from .stride import Stride
from .utils_convolution import optimizing


class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for continuous convolutions.

    Validates and stores the hyper-parameters shared by every continuous
    convolution (number of input/output fields, filter dimension, stride,
    inner model and optimization flags) and declares the interface
    (``forward``, ``transpose_overlap``, ``transpose_no_overlap`` and
    ``_initialize_convolution``) that concrete subclasses must implement.
    """

    def __init__(self, input_numb_field, output_numb_field,
                 filter_dim, stride, model=None, optimize=False,
                 no_overlap=False):
        r"""Base Class for Continuous Convolution.

        The algorithm expects input to be in the form:
        $$[B \times N_{in} \times N \times D]$$
        where $B$ is the batch size, $N_{in}$ is the number of input
        fields, $N$ the number of points in the mesh, $D$ the dimension
        of the problem. In particular:
        * $D$ is the number of spatial variables + 1. The last column must
            contain the field value. For example for 2D problems $D=3$ and
            the tensor will be something like `[first coordinate, second
            coordinate, field value]`.
        * $N_{in}$ is the number of components of the (vector-valued)
            input function. For example a vectorial function
            $f = [f_1, f_2]$ will have $N_{in}=2$.

        :Note
            A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
            input $D=3+1=4$ with 100 points input mesh and batch size of 8
            is represented as a tensor `[8, 2, 100, 4]`, where the columns
            `[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
            second field value respectively.

        The algorithm returns a tensor of shape:
        $$[B \times N_{out} \times N' \times D]$$
        where $B$ is the batch size, $N_{out}$ is the number of output
        fields, $N'$ the number of points in the mesh, $D$ the dimension
        of the problem.

        :param input_numb_field: number of fields in the input
        :type input_numb_field: int
        :param output_numb_field: number of fields in the output
        :type output_numb_field: int
        :param filter_dim: dimension of the filter
        :type filter_dim: tuple/ list
        :param stride: stride for the filter
        :type stride: dict
        :param model: neural network for inner parametrization,
            defaults to None
        :type model: torch.nn.Module, optional
        :param optimize: flag for performing optimization on the continuous
            filter, defaults to False. The flag `optimize=True` should be
            used only when the scatter datapoints are fixed through the
            training. If torch model is in `.eval()` mode, the flag is
            automatically set to False always.
        :type optimize: bool, optional
        :param no_overlap: flag for performing optimization on the transpose
            continuous filter, defaults to False. The flag set to `True` should
            be used only when the filter positions do not overlap for different
            strides. This option is not implemented yet: passing `True`
            raises NotImplementedError.
        :type no_overlap: bool, optional
        :raises ValueError: if an argument does not have the expected type.
        :raises NotImplementedError: if ``no_overlap`` is True.
        """
        super().__init__()

        if not isinstance(input_numb_field, int):
            raise ValueError('input_numb_field must be int.')
        self._input_numb_field = input_numb_field

        if not isinstance(output_numb_field, int):
            raise ValueError('output_numb_field must be int.')
        self._output_numb_field = output_numb_field

        if not isinstance(filter_dim, (tuple, list)):
            raise ValueError('filter_dim must be tuple or list.')
        # Non-persistent buffer: moves with the module (.to(), .cuda())
        # but is not stored in the state_dict.
        self.register_buffer("_dim", torch.tensor(filter_dim),
                             persistent=False)

        if not isinstance(stride, dict):
            raise ValueError('stride must be dictionary.')
        self._stride = Stride(stride)

        self._net = model

        if not isinstance(optimize, bool):
            raise ValueError('optimize must be bool.')
        self._optimize = optimize

        # choosing how to initialize based on optimization
        if self._optimize:
            # optimizing decorator ensures the function is called
            # just once
            self._choose_initialization = optimizing(
                self._initialize_convolution)
        else:
            self._choose_initialization = self._initialize_convolution

        if not isinstance(no_overlap, bool):
            raise ValueError('no_overlap must be bool.')

        if no_overlap:
            # The non-overlapping transpose is not available yet, so fail
            # loudly rather than silently falling back to the overlap one.
            raise NotImplementedError(
                'no_overlap=True is not implemented yet.')
        self.transpose = self.transpose_overlap

    class DefaultKernel(torch.nn.Module):
        """Default inner kernel: a small fully connected network.

        Architecture: Linear(input_dim, 20) -> ReLU -> Linear(20, 20)
        -> ReLU -> Linear(20, output_dim).
        """

        def __init__(self, input_dim, output_dim):
            """
            :param input_dim: size of the last dimension of the input.
            :type input_dim: int
            :param output_dim: size of the last dimension of the output.
            :type output_dim: int
            """
            super().__init__()
            assert isinstance(input_dim, int)
            assert isinstance(output_dim, int)
            self._model = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 20),
                torch.nn.ReLU(),
                torch.nn.Linear(20, 20),
                torch.nn.ReLU(),
                torch.nn.Linear(20, output_dim),
            )

        def forward(self, x):
            """Apply the fully connected network to ``x``."""
            return self._model(x)

    @property
    def net(self):
        """Inner parametrization network passed as ``model`` (may be None)."""
        return self._net

    @property
    def stride(self):
        """The ``Stride`` object built from the ``stride`` dictionary."""
        return self._stride

    @property
    def filter_dim(self):
        """Filter dimension stored as a (non-persistent) tensor buffer."""
        return self._dim

    @property
    def input_numb_field(self):
        """Number of input fields."""
        return self._input_numb_field

    @property
    def output_numb_field(self):
        """Number of output fields."""
        return self._output_numb_field

    @abstractmethod
    def forward(self, X):
        """Forward pass of the continuous convolution (subclass hook)."""

    @abstractmethod
    def transpose_overlap(self, X):
        """Transpose convolution used when ``no_overlap`` is False."""

    @abstractmethod
    def transpose_no_overlap(self, X):
        """Transpose convolution for non-overlapping filter positions."""

    @abstractmethod
    def _initialize_convolution(self, X, type):
        """Initialization hook used through ``_choose_initialization``.

        NOTE(review): the meaning of ``type`` is defined by subclasses;
        it is not established in this base class.
        """
