
# © 2021 Copyright claimant to remain anonymous during evaluation period. All rights reserved. May be used only pursuant to Software Evaluation Terms of Use.  CONFIDENTIAL – MAY CONTAIN TRADE SECRETS


from package.nn import FieldType

from package.gspaces import GSpace3D
from package.group import Representation, Group, Icosahedral
from package.kernels import KernelBasis
from package.kernels import kernels_aliased_Ico_act_R3_icosahedron, kernels_aliased_Ico_act_R3_icosidodecahedron, kernels_aliased_Ico_act_R3_dodecahedron

from .r3convolution import R3Conv

from typing import Callable, Union, List


# Export the class defined in this module. ``R3Conv`` is only imported here to serve as the base class.
__all__ = ["R3IcoConv"]


class R3IcoConv(R3Conv):

    def __init__(self,
                 in_type: FieldType,
                 out_type: FieldType,
                 kernel_size: int,
                 padding: int = 0,
                 stride: int = 1,
                 dilation: int = 1,
                 padding_mode: str = 'zeros',
                 groups: int = 1,
                 bias: bool = True,
                 samples: str = 'ico',
                 sigma: Union[List[float], float] = None,
                 rings: List[float] = None,
                 recompute: bool = False,
                 basis_filter: Callable[[dict], bool] = None,
                 initialize: bool = True,
                 ):
        r"""


        Icosahedral-steerable volumetric convolution mapping between the input and output :class:`~package.nn.FieldType` s specified by
        the parameters ``in_type`` and ``out_type``.
        This operation is equivariant under the action of :math:`\R^3\rtimes I` where :math:`I`
        (the :class:`~package.group.Icosahedral` group) is the :attr:`package.nn.FieldType.fibergroup` of
        ``in_type`` and ``out_type``.

        This class is mostly similar to :class:`~package.nn.R3Conv`, with the only difference that it only supports the
        group :class:`~package.group.Icosahedral` since it uses a kernel basis which is specific for this group.

        The argument ``frequencies_cutoff`` of :class:`~package.nn.R3Conv` is not supported here since the steerable
        kernels are not generated from a band-limited set of harmonic functions.

        Instead, the argument ``samples`` specifies the polyhedron (symmetric with respect to the
        :class:`~package.group.Icosahedral` group) whose vertices are used to define the kernel on :math:`\R^3`.
        The supported polyhedrons are ``"ico"`` (the 12 vertices of the icosahedron), ``"dodeca"`` (the 20 vertices
        of the dodecahedron) or ``"icosidodeca"`` (the 30 vertices of the icosidodecahedron, which correspond to the
        centers of the 30 edges of either the icosahedron or the dodecahedron).

        For each ring ``r`` in ``rings``, the polyhedron specified is embedded in the sphere of radius ``r``.
        The analytical kernel, which is only defined on the vertices of this polyhedron, is then "diffused" in the
        ambient space :math:`\R^3` by means of a small Gaussian kernel with std ``sigma``.

        .. warning ::
            Even if the input tensor has a `coords` attribute, the output of this module will not have one.

        Args:
            in_type (FieldType): the type of the input field, specifying its transformation law
            out_type (FieldType): the type of the output field, specifying its transformation law
            kernel_size (int): the size of the filter
            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
            stride(int, optional): the stride of the kernel. Default: ``1``
            dilation(int, optional): the spacing between kernel elements. Default: ``1``
            padding_mode(str, optional): ``zeros``, ``reflect``, ``replicate`` or ``circular``. Default: ``zeros``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    It allows depthwise convolution. When used, the input and output types need to be
                                    divisible in ``groups`` groups, all equal to each other.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            samples (str, optional): the polyhedron whose vertices are used to define the kernel. One of
                    ``"ico"``, ``"dodeca"`` or ``"icosidodeca"``. Default: ``"ico"``
            sigma (list or float, optional): width of each ring where the bases are sampled. If only one scalar
                    is passed, it is used for all rings.
            rings (list, optional): radii of the rings where to sample the bases
            recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels.
                    By Default (``False``), it  caches the basis built or reuse a cached one, if it is found.
            basis_filter (callable, optional): function which takes as input a descriptor of a basis element
                    (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``)
                    the basis element. By default (``None``), no filtering is applied.
            initialize (bool, optional): initialize the weights of the model. Default: ``True``

        """

        assert isinstance(in_type.gspace, GSpace3D), \
            'R3IcoConv requires "in_type" to be defined on a GSpace3D, found {}'.format(type(in_type.gspace))
        assert isinstance(out_type.gspace, GSpace3D), \
            'R3IcoConv requires "out_type" to be defined on a GSpace3D, found {}'.format(type(out_type.gspace))
        assert in_type.gspace == out_type.gspace, \
            '"in_type" and "out_type" must be defined on the same gspace'
        assert isinstance(in_type.gspace.fibergroup, Icosahedral), \
            'R3IcoConv only supports the Icosahedral fiber group, found {}'.format(type(in_type.gspace.fibergroup))

        assert samples in ['ico', 'dodeca', 'icosidodeca'], \
            'Unsupported value {!r} for "samples": expected "ico", "dodeca" or "icosidodeca"'.format(samples)

        # Stored before calling the parent constructor: presumably R3Conv.__init__ builds the kernel basis
        # through `_build_kernel_basis`, which reads this attribute - TODO confirm against R3Conv.
        self._samples = samples

        super(R3IcoConv, self).__init__(
            in_type,
            out_type,
            kernel_size,
            padding,
            stride,
            dilation,
            padding_mode,
            groups,
            bias,
            sigma,
            frequencies_cutoff=5., # to avoid performing any frequency cut-off, set a large upperbound
            rings=rings,
            recompute=recompute,
            basis_filter=basis_filter,
            initialize=initialize,
        )

    def _build_kernel_basis(self, in_repr: Representation, out_repr: Representation) -> KernelBasis:
        r"""
        Build the kernel basis between ``in_repr`` and ``out_repr`` using the polyhedron selected by ``samples``.

        The values ``self._sigma`` and ``self._rings`` are forwarded to the basis constructor as ``sigma`` and
        ``radii``; they are not set in this class (presumably by the parent :class:`~package.nn.R3Conv`).

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation

        Returns:
            the kernel basis sampled on the vertices of the chosen polyhedron

        Raises:
            ValueError: if the stored ``samples`` value is not one of the supported polyhedrons

        """
        if self._samples == 'ico':
            return kernels_aliased_Ico_act_R3_icosahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings)
        elif self._samples == 'dodeca':
            return kernels_aliased_Ico_act_R3_dodecahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings)
        elif self._samples == 'icosidodeca':
            return kernels_aliased_Ico_act_R3_icosidodecahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings)
        else:
            # Reachable only if `_samples` was altered after construction or the assert in `__init__` was stripped.
            raise ValueError(
                'Unsupported value {!r} for "samples": expected "ico", "dodeca" or "icosidodeca"'.format(self._samples)
            )


