
# © 2021 Copyright claimant to remain anonymous during evaluation period. All rights reserved. May be used only pursuant to Software Evaluation Terms of Use.  CONFIDENTIAL – MAY CONTAIN TRADE SECRETS


from package.gspaces import *
from package.group import CyclicGroup
from package.nn import FieldType
from package.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch

from typing import List, Tuple, Any

import numpy as np

__all__ = ["VectorFieldNonLinearity"]


class VectorFieldNonLinearity(EquivariantModule):
    
    def __init__(self, in_type: FieldType, **kwargs):
        r"""
        
        VectorField non-linearities.
        This non-linearity only supports the regular representation of cyclic group :math:`C_N`, i.e. the group of
        :math:`N` discrete rotations.
        For each input field, the output one is built by taking the rotation associated with the highest
        activation; then, a 2-dimensional vector with an angle with respect to the x-axis equal to that rotation and a
        length equal to its (rectified) activation is set in the output field.
        
        The output type contains one ``irrep_1`` field of the input gspace for each field of the input type.
        
        Args:
            in_type (FieldType): the input field type; every field must be a ``regular`` representation whose size
                    equals the order of the (cyclic) fiber group and which lists ``vectorfield`` among its
                    supported non-linearities
            **kwargs: accepted for signature compatibility; not used by this module
            
        """
        assert isinstance(in_type.gspace, GSpace)

        assert isinstance(in_type.gspace.fibergroup, CyclicGroup)
        assert in_type.gspace.fibergroup.order() > 1
        
        for r in in_type.representations:
            assert 'vectorfield' in r.supported_nonlinearities,\
                'Error! Representation "{}" does not support "vector-field" non-linearity'.format(r.name)
            
            assert r.name == 'regular' and r.size == in_type.gspace.fibergroup.order(), r.name

        super().__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        
        # build the output representation substituting each input field with a rotation representation with frequency 1
        self.out_type = FieldType(self.space, [self.space.representations['irrep_1']] * len(in_type))
        
        # the number of rotations associated with the group action
        self._rotations = self.space.fibergroup.order()
        
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply the VectorField non-linearity to the input feature map.
        
        For every field, the maximum activation over the field's channels is rectified with a ReLU and used as the
        length of a 2-dimensional vector whose angle is ``argmax * 2 * pi / N``, where ``N`` is the order of the
        fiber group. The two components (cosine first, sine second) are interleaved along the channel dimension.
        The output has the same floating point dtype and device as the input tensor.
        
        Args:
            input (GeometricTensor): the input feature map; its type must be equal to ``in_type``

        Returns:
            the resulting feature map, of type ``out_type``
            
        """
        
        assert input.type == self.in_type

        b = input.shape[0]
        spatial_shape = input.shape[2:]

        # split the channel dimension in 2 dimensions, separating fields.
        # `reshape` (instead of `view`) also accepts non-contiguous tensors, and passing the number of fields
        # explicitly (instead of -1) keeps the shape unambiguous when the batch is empty
        fm = input.tensor.reshape(b, len(self.in_type), self._rotations, *spatial_shape)
        
        # evaluate the base rotation associated with the group action
        base_angle = 2 * np.pi / self._rotations
        
        # for each field, retrieve the maximum activation (and the argmax)
        max_activations, argmaxes = torch.max(fm, 2)
        # `max_activations` is a fresh tensor returned by `torch.max`, so rectifying it in-place is safe
        max_activations = torch.relu_(max_activations)
        
        # compute the angles from the index of the maximum activation in the field.
        # Use the dtype of the input (when it is a floating point one) so that the output is not silently cast
        # to single precision
        angle_dtype = fm.dtype if fm.is_floating_point() else torch.float
        max_angles = argmaxes.to(dtype=angle_dtype) * base_angle
        
        # to build the output vectors, take the cosine and the sine of the argmax angle
        # and multiply the 2-dimensional vector by the activation value.
        # Stacking on a new dimension right after the fields one and merging the two dimensions interleaves the
        # components: even channels hold the cosine (x) part, odd channels the sine (y) part
        output = torch.stack(
            (torch.cos(max_angles) * max_activations, torch.sin(max_angles) * max_activations),
            dim=2
        ).reshape(b, self.out_type.size, *spatial_shape)
        
        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type, input.coords)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""
        
        Compute the shape of the output tensor from the shape of the input one.
        
        Args:
            input_shape (tuple): shape of the input tensor; it needs at least 2 dimensions and the second one must
                    be equal to the size of ``in_type``

        Returns:
            the input shape with the channel dimension replaced by the size of ``out_type``
            
        """

        assert len(input_shape) >= 2
        assert input_shape[1] == self.in_type.size

        b = input_shape[0]
        spatial_shape = input_shape[2:]

        return (b, self.out_type.size, *spatial_shape)

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""
        
        Numerically check the equivariance of this module on a random input.
        
        For each element in the testing elements of the gspace, the output of the module transformed by the element
        is compared with the output of the module on the transformed input. Statistics of the absolute error are
        printed for each element.
        
        Args:
            atol (float, optional): the absolute tolerance passed to :func:`torch.allclose`
            rtol (float, optional): the relative tolerance passed to :func:`torch.allclose`

        Returns:
            a list containing, for each testing element, a pair with the element and the mean absolute error
            
        Raises:
            AssertionError: if the two outputs are not close within the given tolerances for some element
            
        """
    
        c = self.in_type.size
    
        x = torch.randn(3, c, 10, 10)
    
        x = GeometricTensor(x, self.in_type)
    
        errors = []
    
        for el in self.space.testing_elements:
            # transforming the output has to be the same as transforming the input
            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))
        
            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
        
            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
                    .format(el, errs.max(), errs.mean(), errs.var())
        
            errors.append((el, errs.mean()))
    
        return errors


