import einops
from torch.nn.functional import conv2d, pad

import e2cnn
from e2cnn import nn
from e2cnn.nn import init
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor
from e2cnn.nn import MaskModule
from e2cnn.nn import EquivariantModule
from e2cnn.gspaces import *

from typing import Callable, Union, Tuple, List, Optional

import torch
from torch.nn import Parameter
from torch.nn import functional as F
import numpy as np
from scipy.linalg import circulant
import math


# Public API for ``from <module> import *``: only the two convolution modules. The basis/mask
# builders and ``check_equivariance`` are still importable by explicit name.
__all__ = ["DefaultR2Conv", "OursR2Conv"]


class DefaultR2Conv(nn.R2Conv):
    """:class:`e2cnn.nn.R2Conv` whose steerable basis is augmented with one central "identity" element.

    For every (output field, input field) pair, one extra learnable coefficient scales a kernel that is
    zero everywhere except at the central tap. At that tap it holds an identity matrix when the input and
    output representations have the same size, and an all-ones block otherwise.

    The learnable ``weights`` vector therefore has
    ``basisexpansion.dimension() + len(in_type) * len(out_type)`` entries. It is read as a
    ``(len(out_type), len(in_type), -1)`` array whose slot 0 per pair is the identity coefficient and
    whose remaining slots are handed to e2cnn's basis expansion.

    NOTE(review): only ``basisexpansion._representations_pairs[0]`` is inspected. This presumably assumes
    that all fields share a single (input, output) representation pair, and square kernels with odd
    ``kernel_size``. Confirm before using it otherwise.
    """

    def __init__(self, *args, **kwargs):
        # Build the stock e2cnn layer first; its ``weights`` and ``filter`` are replaced below.
        super().__init__(*args, **kwargs)

        expansion = self.basisexpansion
        pair = expansion._representations_pairs[0]
        samples = getattr(expansion, f"block_expansion_{pair}").sampled_basis

        side = self.kernel_size  # XXX Assume square kernels
        centre = (side // 2) * side + (side // 2)  # flattened index of the central tap
        # (r_out, r_in, side * side): one sampled basis element, zeroed
        identity = torch.zeros_like(samples[0])
        rows, cols = identity.size(0), identity.size(1)
        if rows != cols:
            # XXX Assume that if not from reg to reg reprs, then we go from trivial to any
            identity[:, :, centre].fill_(1.)
        else:
            # repr_in -> repr_out of equal size (e.g. both regular): identity matrix at the centre tap
            identity[:, :, centre] = torch.eye(rows, dtype=samples.dtype)
        # stored as (r_out, r_in, kernel_size, kernel_size)
        self.register_buffer('id_basis', identity.view(rows, cols, side, side))

        # one identity coefficient per (out field, in field) pair on top of e2cnn's coefficients
        total = expansion.dimension() + len(self.in_type) * len(self.out_type)
        self.register_parameter("weights", torch.nn.Parameter(torch.zeros(total), requires_grad=True))
        self.register_buffer("filter", torch.zeros(self.out_type.size, self.in_type.size, side, side),
                             persistent=False)
        torch.nn.init.normal_(self.weights, mean=0.0, std=1. / math.sqrt(self.in_type.size))

    def expand_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Expand :attr:`weights` into the convolution filter and :attr:`bias` into the expanded bias.

        Coefficient 0 of every (output field, input field) pair scales :attr:`id_basis`. The remaining
        coefficients are expanded by e2cnn's :attr:`basisexpansion`.

        Returns:
            the filter of shape ``(out_type.size, in_type.size, kernel_size, kernel_size)`` and the expanded
            bias (``None`` when the layer has no bias)

        """
        out_size, in_size, side = self.out_type.size, self.in_type.size, self.kernel_size
        per_pair = self.weights.view(len(self.out_type), len(self.in_type), -1)

        steerable = self.basisexpansion(per_pair[:, :, 1:].reshape(-1))
        steerable = steerable.reshape(out_size, in_size, side, side)

        # (fields_out, fields_in, r_out, r_in, h, w) -> (fields_out, r_out, fields_in, r_in, h, w),
        # so that flattening interleaves fields and representation channels like the e2cnn filter
        identity = (per_pair[:, :, 0][:, :, None, None, None, None] * self.id_basis).transpose(1, 2)
        _filter = steerable + identity.reshape(out_size, in_size, side, side)

        _bias = None if self.bias is None else self.bias_expansion @ self.bias
        return _filter, _bias

    def train(self, mode=True):
        """Switch mode; entering eval mode caches the expanded filter/bias, training mode drops them.

        The cached filter is registered as a non-persistent buffer, so it is not saved in the state dict.
        ``nn.R2Conv.train`` is deliberately bypassed in favour of its parent's implementation.
        """
        if not mode:
            # only expand when actually leaving training mode, so repeated `.eval()` calls are free
            if self.training:
                cached_filter, cached_bias = self.expand_parameters()
                self.register_buffer("filter", cached_filter, persistent=False)
                if cached_bias is None:
                    self.expanded_bias = None
                else:
                    self.register_buffer("expanded_bias", cached_bias)
        else:
            # TODO thoroughly check this is not causing problems
            for attr in ("filter", "expanded_bias"):
                if hasattr(self, attr):
                    delattr(self, attr)

        return super(nn.R2Conv, self).train(mode)


def build_radius_mask(kernel_size, margin=0., decay=1., dtype=torch.float32):
    """Return a ``(1, 1, 1, kernel_size, kernel_size)`` mask that fades out taps far from the centre.

    Let ``c = (kernel_size - 1) / 2`` be the kernel centre. A tap at squared distance ``d2`` from the
    centre gets weight ``1`` when ``d2 <= T`` and ``exp((T - d2) / decay)`` otherwise. Here
    ``T = (c - margin / 100 * c) ** 2``, i.e. ``margin`` shrinks the cut-off radius by that percentage.
    Values are computed in Python floats and converted to ``dtype``.
    """
    center = (kernel_size - 1) / 2
    cutoff_sq = (center - margin / 100. * center) ** 2

    def _weight(row, col):
        dist_sq = (row - center) ** 2 + (col - center) ** 2
        if dist_sq > cutoff_sq:
            return math.exp((cutoff_sq - dist_sq) / decay)
        return 1.

    grid = [[_weight(row, col) for col in range(kernel_size)] for row in range(kernel_size)]
    return torch.tensor(grid, dtype=dtype).reshape(1, 1, 1, kernel_size, kernel_size)


def build_edge_mask(kernel_size, modulo, margin=0.5, decay=0.75, dtype=torch.float32):
    """Return a ``(1, 1, 1, kernel_size, kernel_size)`` mask damping taps close to sector boundaries.

    The plane around the centre tap ``kernel_size // 2`` is cut into angular sectors of width
    ``modulo``. For a tap at radius ``r``, ``arc`` is the arc length from the tap to the lower boundary
    ray of its sector, and ``modulo * r - arc`` is the arc length to the upper ray. The weight is:

    - ``exp((arc - margin) / decay)`` when ``arc < margin``;
    - ``exp((modulo * r - margin - arc) / decay)`` when the upper arc is below ``margin``;
    - ``1`` otherwise.

    Note that the centre tap (``r == 0``) falls in the first case.
    """
    center = kernel_size // 2
    sector = float(modulo)
    threshold = margin

    def _weight(row, col):
        d_row, d_col = row - center, col - center
        radius = (d_col ** 2 + d_row ** 2) ** 0.5
        arc = radius * (math.atan2(d_col, d_row) % sector)
        if arc < threshold:
            return math.exp((arc - threshold) / decay)
        if sector * radius - arc < threshold:
            return math.exp((sector * radius - threshold - arc) / decay)
        return 1.

    grid = [[_weight(row, col) for col in range(kernel_size)] for row in range(kernel_size)]
    return torch.tensor(grid, dtype=dtype).reshape(1, 1, 1, kernel_size, kernel_size)


def build_basis(kernel_size, r2_act, in_type, out_type,
                essentially_zero=1e-2, upfactor=3,
                margin_radius=0., decay_radius=1.,
                margin_arc_edges=0.5, decay_arc_edges=0.75):
    """Build a sampled kernel basis for a convolution between trivial/regular fields of ``r2_act``.

    Steps:

    1. Select every kernel tap in the first angular sector: ``x_ >= 0``, ``y_ >= 0`` and
       ``0 <= atan2(y_, x_) <= 2*pi / |G|`` (origin excluded).
    2. For each selected tap, create ``rho_out * rho_in`` single-entry kernels.
    3. Upsample each kernel by ``upfactor`` (nearest neighbour), then, for every element ``g`` of
       ``r2_act.testing_elements``, transform it with ``GeometricTensor.transform`` under a field type
       made of copies of the output representation, and multiply its input index by ``in_repr(g^-1)``.
    4. Sum the results, downsample back to ``kernel_size`` (bilinear), and set entries below
       ``essentially_zero`` to 0.
    5. Multiply by :func:`build_radius_mask` and :func:`build_edge_mask`.
    6. Append basis elements supported only on the central tap (not masked).

    Args:
        kernel_size: side of the square kernel; the origin is tap ``kernel_size // 2``.
        r2_act: the gspace; its fibergroup must be a ``DihedralGroup`` or ``CyclicGroup``.
        in_type, out_type: only ``representations[0]`` of each is inspected (HACK). A trivial
            representation gives size 1; any non-trivial one is treated as regular, of size ``|G|``.
        essentially_zero: entries below this value are set to 0 (this also clears negative entries).
        upfactor: spatial upsampling factor used while transforming the taps.
        margin_radius, decay_radius: forwarded to :func:`build_radius_mask`.
        margin_arc_edges, decay_arc_edges: forwarded to :func:`build_edge_mask`.

    Returns:
        tensor of shape ``(num_basis, rho_out, rho_in, kernel_size, kernel_size)``.

    Raises:
        NotImplementedError: if the fibergroup is neither dihedral nor cyclic.
    """
    card = r2_act.fibergroup.order()
    origin = kernel_size // 2
    modulo = 2 * math.pi / card  # angular width of one sector
    size_up = kernel_size * upfactor
    is_inp_trivial = in_type.representations[0].is_trivial()  # HACK
    is_out_trivial = out_type.representations[0].is_trivial()  # HACK
    if is_inp_trivial:
        rho_in = 1
    else:
        rho_in = card
    # Rebind out_type to rho_in copies of the output representation, so that its size
    # (rho_in * rho_out) matches the channel dimension of the single-tap kernels `k` below.
    if is_out_trivial:
        out_type = FieldType(r2_act, rho_in * [r2_act.trivial_repr,])
        rho_out = 1
    else:
        out_type = FieldType(r2_act, rho_in * [r2_act.regular_repr,])
        rho_out = card
    idx = torch.arange(rho_in * rho_out)
    in_repr = r2_act.trivial_repr if is_inp_trivial else r2_act.regular_repr
    # NOTE(review): out_repr is never used below.
    out_repr = r2_act.trivial_repr if is_out_trivial else r2_act.regular_repr
    bs = []
    for x, y in torch.cartesian_prod(torch.arange(kernel_size), torch.arange(kernel_size)):
        x_ = x - origin
        y_ = y - origin
        angle = math.atan2(y_, x_)
        # Keep only taps in the first sector; both boundary rays (angle 0 and angle == modulo) are kept.
        # NOTE(review): x_ >= 0 and y_ >= 0 is always required. For groups of order < 4 the sector is
        # wider than a quadrant, so taps outside the first quadrant are never used. Confirm this is intended.
        if not (y_ >= 0 and x_ >= 0 and angle <= modulo and angle >= 0):
            continue
        if y_ == 0 and x_ == 0:  # skip the origin
            continue
        # Batch entry b has a single 1 at channel b and tap (x, y). After the view below, the channel
        # splits into (rho_out, rho_in), so the batch enumerates every elementary (out, in) matrix.
        k = torch.zeros(rho_out*rho_in, rho_out*rho_in, kernel_size, kernel_size)
        k[idx, idx, x, y] = 1.
        kup = F.interpolate(k, size_up, mode='nearest')
        # Sum over the test elements g of (transform of the kernel by g) with the input index multiplied
        # by in_repr(g^-1). Presumably testing_elements enumerates the whole finite group, making this a
        # group average up to scale; TODO confirm.
        bup = sum(torch.einsum('bki...,kl->bil...',
            GeometricTensor(kup, out_type).transform(i).tensor.view(-1, rho_out, rho_in, size_up, size_up),
            torch.from_numpy(in_repr(r2_act.fibergroup.inverse(i))).to(dtype=torch.float32)
            ) for i in r2_act.testing_elements)
        b = F.interpolate(einops.rearrange(bup, 'b i l h w -> b (i l) h w'), kernel_size, mode='bilinear', align_corners=True)
        b = b.view(-1, rho_out, rho_in, kernel_size, kernel_size)
        b = torch.where(b < essentially_zero, torch.tensor(0.), b)
        bs.append(b)

    if len(bs) > 0:
        bs = torch.cat(bs)
        # mask off the radius so that rotations are fully supported
        radius_mask = build_radius_mask(kernel_size, margin=margin_radius, decay=decay_radius)
        bs *= radius_mask

        # discount taps close to the sector edges
        # so that we account for discretization error
        edge_mask = build_edge_mask(kernel_size, modulo, margin=margin_arc_edges, decay=decay_arc_edges)
        bs *= edge_mask

    # include bases for the origin
    # generate_circulant(dim, dim2) yields dim2 0/1 matrices A.eq(i), i = 0..dim2-1, where A is the
    # dim x dim circulant matrix of 0..dim-1. A.eq(0) is the identity; when dim2 > dim, the extra
    # matrices (i >= dim) are all zero.
    def generate_circulant(dim, dim2=None):
        dim2 = dim2 or dim
        A = torch.from_numpy(circulant(np.arange(dim))).to(torch.float32)
        for i in range(dim2):
            yield A.eq(i).float()

    if isinstance(r2_act.fibergroup, e2cnn.group.DihedralGroup):
        if is_inp_trivial or is_out_trivial:
            # single origin element: every (out, in) entry equals 1 at the central tap
            k = torch.zeros(1, rho_out*rho_in, kernel_size, kernel_size)
            k[0, :, origin, origin] = 1.
            k = k.view(-1, rho_out, rho_in, kernel_size, kernel_size)
        else:
            # (N + 1)^2 origin elements, N = card // 2, each the block matrix [[a, b], [b, a]].
            # a runs over generate_circulant(N, N + 1) and b over the same list reversed.
            # Index 0 therefore pairs a = identity with b = zero, giving the card x card identity.
            # The pair a = zero, b = zero (index N * (N + 1)) gives an all-zero element.
            k = torch.zeros((card // 2 + 1)**2, card, card, kernel_size, kernel_size)
            for i, a in enumerate(generate_circulant(card // 2, card // 2 + 1)):
                # reversed, so that basis element 0 is the identity
                for j, b in enumerate(reversed(list(generate_circulant(card // 2, card // 2 + 1)))):
                    k[i*(card // 2 + 1)+j, :, :, origin, origin] = torch.cat([torch.cat([a, b]), torch.cat([b, a])], dim=1)
            k = torch.where(k < essentially_zero, torch.tensor(0.), k)
    elif isinstance(r2_act.fibergroup, e2cnn.group.CyclicGroup):
        if is_inp_trivial or is_out_trivial:
            # single origin element: every (out, in) entry equals 1 at the central tap
            k = torch.zeros(1, rho_out*rho_in, kernel_size, kernel_size)
            k[0, :, origin, origin] = 1.
            k = k.view(-1, rho_out, rho_in, kernel_size, kernel_size)
        else:
            # card origin elements: the circulant 0/1 matrices A.eq(i), i = 0..card-1 (index 0 is the identity)
            k = torch.zeros(card, card, card, kernel_size, kernel_size)
            for i, a in enumerate(generate_circulant(card)):
                k[i, :, :, origin, origin] = a
            k = torch.where(k < essentially_zero, torch.tensor(0.), k)
    else:
        raise NotImplementedError()
    # origin elements are appended after masking, so the radius/edge masks do not affect them
    if len(bs) > 0:
        bs = torch.cat([bs, k])
    else:
        bs = k

    return bs


def check_equivariance(r2_act, in_type, out_type, data, func, data_type=None, atol=1e-5):
    """Check that ``func`` commutes with every element of ``r2_act.testing_elements``.

    For each test element ``g``, two outputs are compared: the transformed output of ``func`` on the
    original input, and the output of ``func`` on the transformed input. Their element-wise difference
    must not exceed ``atol``.

    Args:
        r2_act: the gspace providing ``testing_elements``.
        in_type: field type used to wrap ``data`` in a ``GeometricTensor``.
        out_type: unused; kept for backward compatibility.
        data: raw input tensor.
        func: callable mapping a ``GeometricTensor`` to a ``GeometricTensor``. It is assumed to be
            deterministic, because its output on the untransformed input is computed once and reused
            for every element.
        data_type: unused; kept for backward compatibility.
        atol: absolute tolerance on the element-wise difference. Default ``1e-5``.

    Raises:
        AssertionError: carrying ``(g, worst_abs_error)`` for the first element that violates the tolerance.
    """
    data = GeometricTensor(data, in_type)
    # func(data) does not depend on g: evaluate it once instead of once per element
    output = func(data)
    for g in r2_act.testing_elements:
        rg_output = output.transform(g)
        output_rg = func(data.transform(g))

        eps = rg_output.tensor.squeeze() - output_rg.tensor.squeeze()
        worst = eps.abs().max()
        # Compare against zeros of eps's own dtype/device: torch.allclose rejects mixed dtypes, so a
        # float32 scalar breaks on float64 data. Raise explicitly so the check survives `python -O`.
        if not torch.allclose(eps, torch.zeros_like(eps), atol=atol):
            raise AssertionError((g, worst))
    print("Passed Equivariance Test")


class OursR2Conv(EquivariantModule):
    """NOTE: Most of the code is copied from nn.R2Conv module,
       but we adapt to accommodate our custom basis.

    The convolution kernel is expanded in the fixed, sampled basis returned by :func:`build_basis`
    (stored in the ``basis`` buffer) rather than in e2cnn's analytical steerable basis.
    """

    def __init__(self,
                 in_type: FieldType,
                 out_type: FieldType,
                 kernel_size: int,
                 padding: Union[int, Tuple[int, int]] = 0,
                 stride: int = 1,
                 dilation: int = 1,
                 padding_mode: str = 'zeros',
                 groups: int = 1,
                 bias: bool = True,
                 basis_zero: float = 1e-3,
                 basis_upfactor: int = 3,
                 basis_margin_radius: float = 0.,
                 basis_decay_radius: float = 1.,
                 basis_margin_arc_edges: float = 0.5,
                 basis_decay_arc_edges: float = 0.75,
                 initialize: bool = True,
                 ):
        r"""

        G-steerable planar convolution mapping between the input and output :class:`~e2cnn.nn.FieldType` s specified by
        the parameters ``in_type`` and ``out_type``.
        This operation is equivariant under the action of :math:`\R^2\rtimes G` where :math:`G` is the
        :attr:`e2cnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``.

        Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
        :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
        field types.
        Then :class:`~e2cnn.nn.R2Conv` guarantees an equivariant mapping

        .. math::
            \kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2

        where the transformation of the input and output fields are given by

        .. math::
            [\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\
            [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\

        The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an
        equivariant subspace.
        As proven in `3D Steerable CNNs <https://arxiv.org/abs/1807.02547>`_, this parametrizes the *most general
        equivariant convolutional map* between the input and output fields.
        For feature fields on :math:`\R^2` (e.g. images), the complete G-steerable kernel spaces for :math:`G \leq \O2`
        is derived in `General E(2)-Equivariant Steerable CNNs <https://arxiv.org/abs/1911.08251>`_.

        .. note ::

            This module does not use e2cnn's analytical kernel basis. The filter is expanded in the sampled basis built
            by :func:`build_basis`: single kernel taps transformed by ``in_type.gspace.testing_elements``, summed, and
            damped by radial and angular-edge masks. The guarantees above therefore hold only up to the discretization
            error of that basis; :meth:`check_equivariance` can be used to measure it.

        During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights
        before calling :func:`torch.nn.functional.conv2d`.
        When :meth:`~torch.nn.Module.eval()` is called, the filter is built with the current trained weights and stored
        for future reuse such that no overhead of expanding the kernel remains.

        .. warning ::

            When :meth:`~torch.nn.Module.train()` is called, the attributes :attr:`~e2cnn.nn.R2Conv.filter` and
            :attr:`~e2cnn.nn.R2Conv.expanded_bias` are discarded to avoid situations of mismatch with the
            learnable expansion coefficients.
            See also :meth:`e2cnn.nn.R2Conv.train`.

            This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while in
            a mode and lately loading it in a model with a different mode, as the attributes of the class change.
            To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
            dictionary.


        The learnable expansion coefficients of this module are stored in :attr:`weights`, a tensor of shape
        ``(num_basis, len(out_type), len(in_type))``.
        By default, they are initialized in the constructor from a normal distribution with mean ``0`` and standard
        deviation ``1 / sqrt(in_type.size)``. With ``groups > 1``, ``in_type`` here is the type of a single group.

        .. warning ::

            If ``initialize=False``, the weights are allocated with :func:`torch.empty` and hold arbitrary values until
            they are overwritten, e.g. by loading the state dict of a pre-trained model or by another initialization
            method.


        Args:
            in_type (FieldType): the type of the input field, specifying its transformation law
            out_type (FieldType): the type of the output field, specifying its transformation law
            kernel_size (int): the size of the (square) filter; must be odd
            padding (int or tuple, optional): implicit paddings on both sides of the input, either a single integer or
                                    a ``(pad_h, pad_w)`` tuple. Default: ``0``
            padding_mode(str, optional): ``zeros``, ``reflect``, ``replicate`` or ``circular``. Default: ``zeros``
            stride (int, optional): the stride of the kernel. Default: ``1``
            dilation (int, optional): the spacing between kernel elements. Default: ``1``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    It allows depthwise convolution. When used, the input and output types need to be
                                    divisible in ``groups`` groups, all equal to each other.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            basis_zero (float, optional): entries of the sampled basis smaller than this value are set to zero
                                    (``essentially_zero`` of :func:`build_basis`). Default: ``1e-3``
            basis_upfactor (int, optional): spatial upsampling factor used while transforming the kernel taps when
                                    building the basis. Default: ``3``
            basis_margin_radius (float, optional): ``margin`` of the radial mask, as a percentage of the kernel
                                    radius. Default: ``0.``
            basis_decay_radius (float, optional): ``decay`` of the radial mask. Default: ``1.``
            basis_margin_arc_edges (float, optional): ``margin`` of the angular-edge mask. Default: ``0.5``
            basis_decay_arc_edges (float, optional): ``decay`` of the angular-edge mask. Default: ``0.75``
            initialize (bool, optional): initialize the weights of the model. Default: ``True``

        Attributes:

            ~.basis (torch.Tensor): the fixed kernel basis returned by :func:`build_basis`
            ~.weights (torch.Tensor): the learnable parameters which are used to expand the kernel
            ~.filter (torch.Tensor): the convolutional kernel obtained by expanding the parameters
                                    in :attr:`~e2cnn.nn.R2Conv.weights`
            ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
            ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
                                    the parameters in :attr:`~e2cnn.nn.R2Conv.bias`

        """

        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(OursR2Conv, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = out_type

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.groups = groups

        if isinstance(padding, tuple) and len(padding) == 2:
            _padding = padding
        elif isinstance(padding, int):
            _padding = (padding, padding)
        else:
            raise ValueError('padding needs to be either an integer or a tuple containing two integers but {} found'.format(padding))

        padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in padding_modes:
            raise ValueError("padding_mode must be one of [{}], but got padding_mode='{}'".format(padding_modes, padding_mode))
        # F.pad order (last dim first): (pad_w, pad_w, pad_h, pad_h)
        self._reversed_padding_repeated_twice = tuple(x for x in reversed(_padding) for _ in range(2))

        if groups > 1:
            # Check the input and output classes can be split in `groups` groups, all equal to each other
            # first, check that the number of fields is divisible by `groups`
            assert len(in_type) % groups == 0
            assert len(out_type) % groups == 0
            in_size = len(in_type) // groups
            out_size = len(out_type) // groups

            # then, check that all groups are equal to each other, i.e. have the same types in the same order
            assert all(in_type.representations[i] == in_type.representations[i % in_size] for i in range(len(in_type)))
            assert all(out_type.representations[i] == out_type.representations[i % out_size] for i in range(len(out_type)))

            # finally, retrieve the type associated to a single group in input.
            # this type will be used to build a smaller kernel basis and a smaller filter
            # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from one input
            # group to all output groups. Then, PyTorch's standard convolution routine interpret this filter as `groups`
            # different filters, each mapping an input group to an output group.
            in_type = in_type.index_select(list(range(in_size)))

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply bias to a field we learn a bias for each trivial irreps it contains
            # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
            # this is equivalent to transform the field to its irreps through the inverse change of basis,
            # sum the bias only to the trivial irrep and then map it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irreps[irr].is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irreps[irr]
                        if irr.is_trivial():
                            bias_expansion[p:p+r.size, c] = torch.tensor(r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials), requires_grad=True)
                self.register_buffer("expanded_bias", torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        # build_basis places the origin at kernel_size // 2, which is only the centre for odd sizes
        assert(kernel_size % 2 == 1)
        self.register_buffer("basis", build_basis(kernel_size, self.space,
                                                  in_type, out_type,
                                                  essentially_zero=basis_zero,
                                                  upfactor=basis_upfactor,
                                                  margin_radius=basis_margin_radius,
                                                  decay_radius=basis_decay_radius,
                                                  margin_arc_edges=basis_margin_arc_edges,
                                                  decay_arc_edges=basis_decay_arc_edges), persistent=True)
        # one coefficient per (basis element, output field, input field)
        self.weights = Parameter(torch.empty(self.basis.size(0), len(out_type), len(in_type)), requires_grad=True)
        self.register_buffer("filter", torch.zeros(out_type.size, in_type.size, kernel_size, kernel_size), persistent=False)

        if initialize:
            # plain normal initialization (not e2cnn's generalized He init) with std 1 / sqrt(fan-in channels)
            std = 1. / math.sqrt(in_type.size)
            torch.nn.init.normal_(self.weights, mean=0.0, std=std)

    def expand_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""

        Expand the filter in terms of the :attr:`e2cnn.nn.R2Conv.weights` and the
        expanded bias in terms of :class:`e2cnn.nn.R2Conv.bias`.

        Returns:
            the expanded filter and bias

        """
        # basis: (num_basis, rho_out, rho_in, h, w); weights: (num_basis, out_fields, in_fields).
        # Contracting the basis index gives (out_fields, rho_out, in_fields, rho_in, h, w), which is
        # flattened field-major into (out_fields * rho_out, in_fields * rho_in, h, w).
        _filter = torch.einsum('bijhw,bmn->minjhw', self.basis, self.weights).contiguous()
        _filter = einops.rearrange(_filter, 'm i n j h w -> (m i) (n j) h w')

        if self.bias is None:
            _bias = None
        else:
            _bias = self.bias_expansion @ self.bias

        return _filter, _bias

    def forward(self, input: GeometricTensor):
        r"""
        Convolve the input with the expanded filter and bias.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``

        Returns:
            output feature field transforming according to ``out_type``

        """

        assert input.type == self.in_type

        if not self.training:
            _filter = self.filter
            _bias = self.expanded_bias
        else:
            # retrieve the filter and the bias
            _filter, _bias = self.expand_parameters()

        # use it for convolution and return the result

        if self.padding_mode == 'zeros':
            output = conv2d(input.tensor, _filter,
                            stride=self.stride,
                            padding=self.padding,
                            dilation=self.dilation,
                            groups=self.groups,
                            bias=_bias)
        else:
            output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
                            _filter,
                            stride=self.stride,
                            dilation=self.dilation,
                            padding=(0,0),
                            groups=self.groups,
                            bias=_bias)

        return GeometricTensor(output, self.out_type)

    def train(self, mode=True):
        r"""

        If ``mode=True``, the method sets the module in training mode and discards the :attr:`~e2cnn.nn.R2Conv.filter`
        and :attr:`~e2cnn.nn.R2Conv.expanded_bias` attributes.

        If ``mode=False``, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using
        the current values of the trainable parameters and store them in :attr:`~e2cnn.nn.R2Conv.filter` and
        :attr:`~e2cnn.nn.R2Conv.expanded_bias` such that they are not recomputed at each forward pass.

        .. warning ::

            This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while in
            a mode and lately loading it in a model with a different mode, as the attributes of this class change.
            To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
            dictionary.

        Args:
            mode (bool, optional): whether to set training mode (``True``) or evaluation mode (``False``).
                                   Default: ``True``.

        """

        if mode:
            # TODO thoroughly check this is not causing problems
            if hasattr(self, "filter"):
                del self.filter
            if hasattr(self, "expanded_bias"):
                del self.expanded_bias
        elif self.training:
            # avoid re-computation of the filter and the bias on multiple consecutive calls of `.eval()`

            _filter, _bias = self.expand_parameters()

            self.register_buffer("filter", _filter, persistent=False)
            if _bias is not None:
                self.register_buffer("expanded_bias", _bias)
            else:
                self.expanded_bias = None

        return super(OursR2Conv, self).train(mode)

    def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        r"""
        Compute the shape of the output of :meth:`forward` for an input of shape ``input_shape``.

        ``padding`` may be an integer or a ``(pad_h, pad_w)`` tuple, as accepted by the constructor.
        """
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size

        b, c, hi, wi = input_shape

        # the constructor accepts tuple padding, so `2 * self.padding` cannot be used directly
        if isinstance(self.padding, int):
            pad_h = pad_w = self.padding
        else:
            pad_h, pad_w = self.padding

        ho = math.floor((hi + 2 * pad_h - self.dilation * (self.kernel_size - 1) - 1) / self.stride + 1)
        wo = math.floor((wi + 2 * pad_w - self.dilation * (self.kernel_size - 1) - 1) / self.stride + 1)

        return b, self.out_type.size, ho, wo

    def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1, assertion: bool = True, verbose: bool = True):
        r"""
        Numerically test equivariance on a real image for every element of ``self.space.testing_elements``.

        The image is loaded from ``'../group/testimage.jpeg'`` (relative to the current working directory).
        The central disc of the block-averaged outputs is compared with tolerance ``rtol * max(|out1|, |out2|) + atol``.

        Returns:
            a list of ``(element, mean absolute error)`` pairs
        """

        feature_map_size = 33
        last_downsampling = 5
        first_downsampling = 5

        initial_size = (feature_map_size * last_downsampling - 1 + self.kernel_size) * first_downsampling

        c = self.in_type.size

        import matplotlib.image as mpimg
        from skimage.measure import block_reduce
        from skimage.transform import resize

        x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]

        x = resize(
            x,
            (x.shape[0], x.shape[1], initial_size, initial_size),
            anti_aliasing=True
        )

        x = x / 255.0 - 0.5

        if x.shape[1] < c:
            to_stack = [x for i in range(c // x.shape[1])]
            if c % x.shape[1] > 0:
                to_stack += [x[:, :(c % x.shape[1]), ...]]

            x = np.concatenate(to_stack, axis=1)

        x = GeometricTensor(torch.FloatTensor(x), self.in_type)

        def shrink(t: GeometricTensor, s) -> GeometricTensor:
            return GeometricTensor(torch.FloatTensor(block_reduce(t.tensor.detach().numpy(), s, func=np.mean)), t.type)

        errors = []

        for el in self.space.testing_elements:

            out1 = self(shrink(x, (1, 1, 5, 5))).transform(el).tensor.detach().numpy()
            out2 = self(shrink(x.transform(el), (1, 1, 5, 5))).tensor.detach().numpy()

            out1 = block_reduce(out1, (1, 1, 5, 5), func=np.mean)
            out2 = block_reduce(out2, (1, 1, 5, 5), func=np.mean)

            b, c, h, w = out2.shape

            center_mask = np.zeros((2, h, w))
            center_mask[1, :, :] = np.arange(0, w) - w / 2
            center_mask[0, :, :] = np.arange(0, h) - h / 2
            center_mask[0, :, :] = center_mask[0, :, :].T
            center_mask = center_mask[0, :, :] ** 2 + center_mask[1, :, :] ** 2 < (h / 4) ** 2

            out1 = out1[..., center_mask]
            out2 = out2[..., center_mask]

            out1 = out1.reshape(-1)
            out2 = out2.reshape(-1)

            errs = np.abs(out1 - out2)

            esum = np.maximum(np.abs(out1), np.abs(out2))
            esum[esum == 0.0] = 1

            relerr = errs / esum

            if verbose:
                print(el, relerr.max(), relerr.mean(), relerr.var(), errs.max(), errs.mean(), errs.var())

            tol = rtol * esum + atol

            if np.any(errs > tol) and verbose:
                print(out1[errs > tol])
                print(out2[errs > tol])
                print(tol[errs > tol])

            if assertion:
                assert np.all(errs < tol), 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'.format(el, errs.max(), errs.mean(), errs.var())

            errors.append((el, errs.mean()))

        return errors

    def export(self):
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.Conv2d` module and set to "eval" mode.

        """

        # set to eval mode so the filter and the bias are updated with the current
        # values of the weights
        self.eval()
        _filter = self.filter
        _bias = self.expanded_bias

        if self.padding_mode not in ['zeros']:
            x, y = torch.__version__.split('.')[:2]
            # compare (major, minor) as a tuple: testing `minor < 5` on its own would wrongly reject e.g. torch 2.0
            if (int(x), int(y)) < (1, 5):
                if self.padding_mode == 'circular':
                    raise ImportError(
                        "'{}' padding mode had some issues in old `torch` versions. Therefore, we only support conversion from version 1.5 but only version {} is installed.".format(
                            self.padding_mode, torch.__version__
                        )
                    )

                else:
                    raise ImportError(
                        "`torch` supports '{}' padding mode only from version 1.5 but only version {} is installed.".format(
                            self.padding_mode, torch.__version__
                        )
                    )

        # build the PyTorch Conv2d module
        has_bias = self.bias is not None
        conv = torch.nn.Conv2d(self.in_type.size,
                               self.out_type.size,
                               self.kernel_size,
                               padding=self.padding,
                               padding_mode=self.padding_mode,
                               stride=self.stride,
                               dilation=self.dilation,
                               groups=self.groups,
                               bias=has_bias)

        # set the filter and the bias
        conv.weight.data = _filter.data
        if has_bias:
            conv.bias.data = _bias.data

        return conv

    def __repr__(self):
        extra_lines = []
        extra_repr = self.extra_repr()
        if extra_repr:
            extra_lines = extra_repr.split('\n')

        main_str = self._get_name() + '('
        if len(extra_lines) == 1:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(extra_lines) + '\n'

        main_str += ')'
        return main_str

    def extra_repr(self):
        s = ('OURS|{in_type}, {out_type}, kernel_size={kernel_size}, stride={stride}')
        if self.padding != 0 and self.padding != (0, 0):
            s += ', padding={padding}'
        if self.dilation != 1 and self.dilation != (1, 1):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        return s.format(**self.__dict__)
