
# © 2021 Copyright claimant to remain anonymous during evaluation period. All rights reserved. May be used only pursuant to Software Evaluation Terms of Use.  CONFIDENTIAL – MAY CONTAIN TRADE SECRETS


import numpy as np

from .basis import KernelBasis
from .steerable_basis import SteerableKernelBasis, IrrepBasis
from .spaces import SpaceIsomorphism
from .wignereckart_solver import WignerEckartBasis

from package.group import Representation

from typing import Type, Union, Tuple, Dict, List, Iterable, Callable, Set


class SparseSteerableBasis(KernelBasis):
    r"""
    Steerable kernel basis over a finite homogeneous space whose points are embedded in an ambient Euclidean space
    and whose kernel basis is "diffused" around those points with a Gaussian.
    """

    def __init__(self,
                 X: SpaceIsomorphism,
                 in_repr: Representation,
                 out_repr: Representation,
                 sigma: float,
                 harmonics: List = None,
                 change_of_basis: np.ndarray = None,
                 attributes: Dict = None):
        r"""

        Class which implements a steerable basis for a discrete homogeneous space ``X`` by embedding the finite
        number of points of ``X`` into :math:`\mathbb{R}^n` and then "diffusing" the kernel basis defined over these
        points in the ambient space using a Gaussian kernel.

        This class only supports finite homogeneous spaces generated by a finite symmetry group, i.e. not only
        must :math:`X=G/H` have a finite number of elements, but :math:`G` itself must be finite too.

        Args:
            X: the homogeneous space. The embedded points are obtained by applying ``X.projection`` to every element
                of ``X.G``; points that coincide up to 4 decimals are merged, and exactly
                ``X.G.order() // X.H.order()`` distinct points are expected to remain.
            in_repr: input representation; must be defined over the same group as ``out_repr``.
            out_repr: output representation.
            sigma: strictly positive standard deviation of the Gaussian used to diffuse the basis around each
                embedded point.
            harmonics (optional): forwarded unchanged to the underlying :class:`SteerableKernelBasis` (built with
                :class:`WignerEckartBasis`).
            change_of_basis (optional): invertible ``(X.dim, X.dim)`` matrix applied to the embedded points. It is
                applied *after* the harmonics have been computed on the original (untransformed) points.
            attributes (optional): extra key/value pairs merged into the attribute dictionary of every basis element.

        """

        self.X = X
        assert in_repr.group == out_repr.group
        self.group = in_repr.group
        self.in_repr = in_repr
        self.out_repr = out_repr

        # check that the homogeneous space has a finite number of elements
        # unfortunately, we can not directly check the size of the homogeneous space G/H so we check that G itself has
        # only a finite number of elements
        assert self.X.G.order() > 0

        # one column per group element: the projection of g onto X, embedded in R^dim
        points = np.concatenate([
            self.X.projection(g).reshape(-1, 1) for g in self.X.G.elements
        ], axis=1)

        # different group elements may project onto the same point of X: keep only one column per distinct point
        # (rounding makes the deduplication robust to floating point noise)
        _, idx = np.unique(points.round(decimals=4), axis=1, return_index=True)
        points = points[:, idx]

        # |G/H| = |G| / |H| distinct points are expected (integer division: this is a count)
        assert points.shape == (self.X.dim, self.X.G.order() // self.X.H.order())

        assert sigma > 0., sigma

        self.points = points
        self.sigma = sigma

        self.wigner_basis = SteerableKernelBasis(self.X, in_repr, out_repr, WignerEckartBasis, harmonics=harmonics)

        # harmonics evaluated on the (untransformed) embedded points
        self._harmonics = self.wigner_basis.compute_harmonics(self.points)

        dim = len(self.wigner_basis)

        super().__init__(dim, (out_repr.size, in_repr.size))

        if change_of_basis is not None:
            assert change_of_basis.shape == (self.X.dim, self.X.dim)
            # check that the matrix is invertible
            assert np.isfinite(np.linalg.cond(change_of_basis))

            self.points = change_of_basis @ self.points

        self.change_of_basis = change_of_basis

        self._attributes = attributes if attributes is not None else dict()

    def sample(self, points: np.ndarray, out: np.ndarray = None) -> np.ndarray:
        r"""

        Sample the continuous basis elements on the discrete set of points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(d, N)`, where `N` is the number of points.

        Each harmonic, known on the embedded points of ``X``, is spread over the sampled points with the Gaussian
        weight :math:`\exp(-\|x - p\|^2 / (2 \sigma^2))` between every sampled point :math:`x` and embedded
        point :math:`p`.

        Args:
            points (~numpy.ndarray): points where to evaluate the basis elements
            out (~numpy.ndarray, optional): pre-existing array to use to store the output; must have shape
                ``(self.shape[0], self.shape[1], self.dim, N)``

        Returns:
            the sampled basis

        """

        assert len(points.shape) == 2
        S = points.shape[1]
        assert points.shape[0] == self.X.dim

        expected_shape = (self.shape[0], self.shape[1], self.dim, S)
        if out is None:
            out = np.empty(expected_shape)

        assert out.shape == expected_shape

        # pairwise differences between every embedded point (axis 1) and every sampled point (axis 2)
        diff = np.expand_dims(self.points, 2) - np.expand_dims(points, 1)
        assert diff.shape == (self.X.dim, self.points.shape[1], S)

        # Gaussian weights, shape (num_embedded_points, S)
        weights = np.exp(- 0.5 * ((diff ** 2).sum(axis=0) / self.sigma ** 2))

        harmonics = {}
        outs = {}
        offset = 0
        for j in self.wigner_basis.js:
            # diffuse the harmonics known on the embedded points over the sampled points
            harmonics[j] = np.einsum('rmi,io->rmo', self._harmonics[j], weights)

            dim_j = self.dim_harmonic(j)
            # slices of ``out`` (basic slicing -> views); presumably filled in place by ``sample_harmonics``,
            # since its return value is discarded below
            outs[j] = out[:, :, offset:offset + dim_j, :]
            offset += dim_j

        self.sample_harmonics(harmonics, outs)
        return out

    def sample_harmonics(self, points: Dict[Tuple, np.ndarray], out: Dict[Tuple, np.ndarray] = None) -> Dict[Tuple, np.ndarray]:
        r"""Delegate to the wrapped Wigner-Eckart basis' ``sample_harmonics``."""
        return self.wigner_basis.sample_harmonics(points, out)

    def dim_harmonic(self, j: Tuple) -> int:
        r"""Number of basis elements associated with the harmonic ``j`` (delegated to the wrapped basis)."""
        return self.wigner_basis.dim_harmonic(j)

    def _decorate_attr(self, attr: Dict) -> Dict:
        # Add ``sigma`` and the user-supplied attributes to an attribute dict obtained from the wrapped basis.
        # NOTE(review): the dict is modified in place, as in the original code; if the wrapped basis caches and
        # returns the same dict objects, they will be updated as well.
        attr['sigma'] = self.sigma
        attr.update(**self._attributes)
        return attr

    def attrs_j_iter(self, j: Tuple) -> Iterable:
        r"""Iterate over the attributes of the basis elements associated with the harmonic ``j``."""
        for attr in self.wigner_basis.attrs_j_iter(j):
            yield self._decorate_attr(attr)

    def attrs_j(self, j: Tuple, idx) -> Dict:
        r"""Attributes of the ``idx``-th basis element associated with the harmonic ``j``."""
        return self._decorate_attr(self.wigner_basis.attrs_j(j, idx))

    def __getitem__(self, idx):
        return self._decorate_attr(self.wigner_basis[idx])

    def __iter__(self):
        for attr in self.wigner_basis:
            yield self._decorate_attr(attr)

    def __eq__(self, other):
        if not isinstance(other, SparseSteerableBasis):
            return False
        elif self.wigner_basis != other.wigner_basis:
            return False
        elif not np.isclose(self.sigma, other.sigma):
            return False
        elif (self.change_of_basis is None) != (other.change_of_basis is None):
            return False
        elif self.change_of_basis is not None:
            return np.allclose(self.change_of_basis, other.change_of_basis)
        else:
            return True

    def __hash__(self):
        # ``__eq__`` compares ``sigma`` and ``change_of_basis`` only up to a numerical tolerance, so hashing their
        # exact values would let two equal objects have different hashes. Only the exactly-compared component
        # participates in the hash.
        return hash(self.wigner_basis)

