# Copyright 2019 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import functional as Fsp
from . import ops
from .modules import SparseModule
from .structure import SparseConvTensor


class SparseMaxPool(SparseModule):
    def __init__(self, ndim, kernel_size, stride=1, padding=0, dilation=1, subm=False):
        super(SparseMaxPool, self).__init__()
        if not isinstance(kernel_size, (list, tuple)):
            kernel_size = [kernel_size] * ndim
        if not isinstance(stride, (list, tuple)):
            stride = [stride] * ndim
        if not isinstance(padding, (list, tuple)):
            padding = [padding] * ndim
        if not isinstance(dilation, (list, tuple)):
            dilation = [dilation] * ndim

        self.ndim = ndim
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.subm = subm
        self.dilation = dilation

    def forward(self, input):
        assert isinstance(input, SparseConvTensor)
        features = input.features
        device = features.device
        indices = input.indices
        spatial_shape = input.spatial_shape
        batch_size = input.batch_size
        if not self.subm:
            out_spatial_shape = ops.get_conv_output_size(
                spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation
            )
        else:
            out_spatial_shape = spatial_shape
        outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
            indices,
            batch_size,
            spatial_shape,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            0,
            self.subm,
        )

        out_features = Fsp.indice_maxpool(
            features, indice_pairs.to(device), indice_pairs_num.to(device), outids.shape[0]
        )
        out_tensor = SparseConvTensor(out_features, outids, out_spatial_shape, batch_size)
        out_tensor.indice_dict = input.indice_dict
        out_tensor.grid = input.grid
        return out_tensor


class SparseMaxPool2d(SparseMaxPool):
    def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
        super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding, dilation)


class SparseMaxPool3d(SparseMaxPool):
    def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
        super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding, dilation)
