
# © 2021 Copyright claimant to remain anonymous during evaluation period. All rights reserved. May be used only pursuant to Software Evaluation Terms of Use.  CONFIDENTIAL – MAY CONTAIN TRADE SECRETS


import torch

from package.gspaces import *
from .equivariant_module import EquivariantModule
from package.nn import GeometricTensor
from package.nn import FieldType

from typing import List, Tuple, Any

import numpy as np

__all__ = ["ReshuffleModule"]


class ReshuffleModule(EquivariantModule):
    r"""
    Module that reorders the fields of a :class:`GeometricTensor`.

    The reordering is precomputed at construction time as a flat list of channel indices (stored in the
    ``indices`` buffer) and applied in :meth:`forward` by indexing dimension ``1`` of the input tensor.
    """
    
    def __init__(self,
                 in_type: FieldType,
                 permutation: List[int]
                 ):
        r"""
        
        Permutes the fields of the input tensor according to the input ``permutation``.
        
        The parameter ``permutation`` should be a list containing a permutation of the integers ``{0, 1, ..., n-1}``,
        where ``n`` is the number of fields of ``in_type`` (see :meth:`package.nn.FieldType.__len__`).
        Each index must appear exactly once: lists with repeated or missing indices are rejected.
        
        
        Args:
            in_type (FieldType): input field type
            permutation (list): permutation to apply
            
        """
    
        assert isinstance(in_type.gspace, GSpace)
        
        super(ReshuffleModule, self).__init__()
        
        # check if it is actually a permutation of the fields.
        # Comparing sets alone is not enough: a list such as [0, 0, 1] has the same set of values as
        # range(2) but is not a permutation, so the length is checked as well.
        n_fields = len(in_type.representations)
        assert len(permutation) == n_fields and set(permutation) == set(range(n_fields)), \
            'The list "permutation" must contain each of the integers 0, ..., {} exactly once, but {} was given' \
                .format(n_fields - 1, permutation)
        
        self.space = in_type.gspace
        self.in_type = in_type
        
        # permute the fields in the input representation to build the output representation
        self.out_type = in_type.index_select(permutation)
        
        # compute the starting position of each field in the input representation
        positions = []
        last_position = 0
        for r in in_type.representations:
            positions.append(last_position)
            last_position += r.size

        # compute the indices for the permutation:
        # the c-th input field occupies the channels [positions[c], positions[c] + size)
        indices = []
        for c in permutation:
            size = in_type.representations[c].size
            indices.extend(range(positions[c], positions[c] + size))
        
        # stored as a buffer (not a parameter) so that it is part of the module's state
        self.register_buffer("indices", torch.LongTensor(indices))
        
    def forward(self, input: GeometricTensor):
        r"""
        
        Reorder the channels of ``input`` according to the precomputed ``indices`` buffer.
        
        Args:
            input (GeometricTensor): the input tensor; its type must be equal to ``in_type``
            
        Returns:
            a :class:`GeometricTensor` of type ``out_type`` built from the reordered data and ``input.coords``
            
        """
        
        assert input.type == self.in_type
        
        # retrieve the values from the input using the permutation of the indices computed before
        data = input.tensor[:, self.indices, ...]
        
        return GeometricTensor(data, self.out_type, input.coords)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""
        
        Compute the shape of the output tensor given the shape of the input tensor.
        
        Args:
            input_shape (tuple): shape of the input; it needs at least two entries and the second one
                must be equal to ``in_type.size``
            
        Returns:
            the input shape with its second entry replaced by ``out_type.size``
            
        """

        assert len(input_shape) >= 2
        assert input_shape[1] == self.in_type.size

        b = input_shape[0]
        spatial_shape = input_shape[2:]

        return (b, self.out_type.size, *spatial_shape)

    def check_equivariance(self, atol: float = 1e-7, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""
        
        Check that applying the module commutes with ``transform_fibers`` on a random input.
        
        For each element in ``self.space.testing_elements`` the two outputs are compared with
        :func:`torch.allclose` and the error statistics are printed to standard output.
        
        Args:
            atol (float): absolute tolerance passed to :func:`torch.allclose`
            rtol (float): relative tolerance passed to :func:`torch.allclose`
            
        Returns:
            a list of pairs ``(element, mean absolute error)``, one per tested element
            
        Raises:
            AssertionError: if the two outputs are not close for some element
            
        """
    
        c = self.in_type.size
    
        # NOTE(review): the random input is hard-coded to the shape (3, c, 10, 10), i.e. two trailing
        # dimensions after the channels - confirm this matches the gspace in use.
        x = torch.randn(3, c, 10, 10)
    
        x = GeometricTensor(x, self.in_type)
    
        errors = []
    
        for el in self.space.testing_elements:
            # transform after vs. before applying the module
            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))
        
            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
        
            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
                    .format(el, errs.max(), errs.mean(), errs.var())
        
            errors.append((el, errs.mean()))
    
        return errors