
# © 2021 Copyright claimant to remain anonymous during evaluation period. All rights reserved. May be used only pursuant to Software Evaluation Terms of Use.  CONFIDENTIAL – MAY CONTAIN TRADE SECRETS

from __future__ import annotations

from package.group import change_basis, directsum
from package.group.irrep import generate_irrep_matrices_from_generators
from package.group.irrep import restrict_irrep
from package.group.utils import cycle_isclose

from .utils import *

from .so3_utils import PARAMETRIZATION as PARAMETRIZATION_SO3
from .so3_utils import PARAMETRIZATIONS
from .so3_utils import IDENTITY, _grid, _combine, _equal, _invert, _change_param, _check_param, _hash

from .so3group import _build_character, _build_irrep

import numpy as np

from typing import Tuple, Callable, Iterable, List, Dict, Any, Union


# only the group class is part of this module's public API
__all__ = ["Octahedral"]

# golden ratio (1 + sqrt(5)) / 2
# NOTE(review): not referenced anywhere in the visible part of this file - possibly unused here, confirm before removing
_PHI = (1. + np.sqrt(5)) / 2


class Octahedral(Group):
    r"""
    The Octahedral group :math:`O`, i.e. the 24 rotations in :math:`SO(3)` which map the cube (equivalently, the
    octahedron) to itself.

    In matrix form, its elements are exactly the signed permutation matrices with positive determinant (see
    :meth:`~Octahedral._is_element`).
    Group elements support the same parametrizations as :math:`SO(3)`.
    """

    PARAM = PARAMETRIZATION_SO3

    PARAMETRIZATIONS = PARAMETRIZATIONS

    def __init__(self):
        r"""

        Build the Octahedral group :math:`O`, its 24 elements, its generators and its irreducible representations.

        Subgroup Structure:

        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    ``id[0]``                      |    ``id[1]``                      |    subgroup                                                                                                       |
        +===================================+===================================+===================================================================================================================+
        |        'octa'                     |                                   |   The Octahedral :math:`O` group itself                                                                           |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |        'tetra'                    |                                   |   Tetrahedral :math:`T` subgroup                                                                                  |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |        False                      |     N = 1, 2, 3, 4                |   :math:`C_N` of N discrete planar rotations                                                                      |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |        True                       |     N = 2, 3, 4                   |   *dihedral* :math:`D_N` subgroup of N discrete planar rotations and out-of-plane :math:`\pi` rotation            |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |        True                       |     1                             |   equivalent to ``(False, 2, adj)``                                                                               |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+

        A subgroup id can optionally end with an element ``adj`` of this group, which selects a conjugated copy of the
        subgroup (e.g., for the cyclic subgroups, the image of the inclusion map is ``~adj @ C_N @ adj``).
        If ``adj`` is omitted, the identity element is used.

        .. note ::
            The Tetrahedral and the dihedral subgroups are listed above but are not implemented yet:
            requesting them raises a :class:`NotImplementedError`.

        """

        super().__init__("Octahedral", False, False)

        self._identity = self.element(IDENTITY)

        self._elements = [self.element(g) for g in _grid('cube')]
        assert len(self._elements) == 24

        # NOTE: these indices depend on the order in which `_grid('cube')` enumerates the elements
        self._generators = [
            self._elements[17],  # Cyclic Group of order 4
            self._elements[11],  # Cyclic Group of order 3
            self._elements[22],  # Cyclic Group of order 2
        ]

        self._build_representations()

    @property
    def generators(self) -> List[GroupElement]:
        return self._generators

    @property
    def identity(self) -> GroupElement:
        return self._identity

    @property
    def elements(self) -> List[GroupElement]:
        return self._elements

    @property
    def _keys(self) -> Dict[str, Any]:
        # the group has no parameters
        return dict()

    @property
    def subgroup_trivial_id(self):
        # NOTE(review): not implemented, although `_subgroup` accepts `(False, 1)` for the trivial subgroup
        raise NotImplementedError

    @property
    def subgroup_self_id(self):
        # NOTE(review): not implemented, although `_subgroup` accepts `'octa'` for the whole group
        raise NotImplementedError

    ###########################################################################
    # METHODS DEFINING THE GROUP LAW AND THE OPERATIONS ON THE GROUP'S ELEMENTS
    ###########################################################################

    def _inverse(self, element, param=PARAM):
        r"""
        Return the inverse element of the input element
        """
        return _invert(element, param=param)

    def _combine(self, e1, e2,
                 param=PARAM,
                 param1=None,
                 param2=None
                 ):
        r"""
        Return the composition (group product) of the two input elements
        """
        return _combine(e1, e2, param=param, param1=param1, param2=param2)

    def _equal(self, e1, e2,
               param=PARAM,
               param1=None,
               param2=None,
               ) -> bool:
        r"""
        Check if the two input values correspond to the same element
        """
        return _equal(e1, e2, param=param, param1=param1, param2=param2)

    def _hash_element(self, element, param=PARAM):
        return _hash(element, param)

    def _repr_element(self, element, param=PARAM):
        return element.__repr__()

    def _is_element(self, element,
                    param: str = PARAM,
                    verbose: bool = False,
                    ) -> bool:
        r"""
        Check if ``element``, expressed in the parametrization ``param``, is one of the rotations of the Octahedral
        group, i.e. if it is a rotation whose matrix is a signed permutation matrix.
        If ``verbose`` is set, a message explaining why the check failed is printed.
        """

        ATOL = 1e-7
        RTOL = 1e-5

        if not _check_param(element, param):
            if verbose:
                print(f"Element {element} is not a rotation")
            return False

        # convert to matrix representation
        element = self._change_param(element, param, 'MAT')

        # take the absolute value of the entries
        # note that `_check_param` above already ensured that the determinant is positive, since it checks
        # that the element is a rotation
        at = np.abs(element)

        # check if the matrix is a permutation matrix
        ans = bool(
                np.isclose(at.sum(axis=0), 1., atol=ATOL, rtol=RTOL).all()
            and np.isclose(at.sum(axis=1), 1., atol=ATOL, rtol=RTOL).all()
            and (np.isclose(at, 1., atol=ATOL, rtol=RTOL) | np.isclose(at, 0., atol=ATOL, rtol=RTOL)).all()
        )

        if not ans and verbose:
            print(f"Element {element} is not a rotation of the Octahedral group")

        return ans

    def _change_param(self, element, p_from: str, p_to: str):
        assert p_from in self.PARAMETRIZATIONS
        assert p_to in self.PARAMETRIZATIONS
        return _change_param(element, p_from, p_to)

    ###########################################################################

    def sample(self, param: str = PARAM) -> GroupElement:
        # uniform sampling over the 24 elements; `param` is not used since group elements are returned
        return self._elements[
            np.random.randint(self.order())
        ]

    def testing_elements(self) -> Iterable[GroupElement]:
        r"""
        A finite number of group elements to use for testing.
        """
        return iter(self._elements)

    def __eq__(self, other):
        if not isinstance(other, Octahedral):
            return False
        else:
            return self.name == other.name

    def __hash__(self):
        # defining `__eq__` alone implicitly sets `__hash__ = None`, which would make the group unhashable;
        # hash the same attribute compared by `__eq__` to keep the two consistent
        return hash(self.name)

    def _process_subgroup_id(self, id):
        r"""
        Normalize a subgroup id to the form ``(*id, adj)``, where ``adj`` is an element of this group.

        A non-tuple id is wrapped into a tuple and, if the id does not end with a :class:`GroupElement`, the identity
        is appended as ``adj``.
        The id ``(True, 1, adj)`` (the flip subgroup of the :math:`O(2)` subgroup of :math:`SO(3)`) is rewritten as the
        equivalent cyclic subgroup ``(False, 2, adj')``, where ``adj'`` composes ``adj`` with a fixed rotation by
        :math:`2\pi/3` around the axis :math:`(1, 1, -1)`.
        """

        if not isinstance(id, tuple):
            id = (id,)

        assert isinstance(id[0], bool) or isinstance(id[0], str), id[0]

        if not isinstance(id[-1], GroupElement):
            id = (*id, self.identity)

        assert id[-1].group == self

        if isinstance(id[0], bool):
            assert id[1] in [1, 2, 3, 4]

            if id[0] == True and id[1] == 1:
                # flip subgroup of the O(2) subgroup of SO(3)
                # this is equivalent to the C_2 subgroup of 180 deg rotations out of the plane (around X axis)

                V = np.array([1., 1., -1.])
                V /= np.linalg.norm(V)

                # unit quaternion (scalar part last) of the rotation by 2pi/3 around V
                change_axis = np.zeros(4)
                change_axis[:3] = V * np.sin(np.pi/3.)
                change_axis[3] = np.cos(np.pi/3.)

                adj = self.element(change_axis, 'Q') @ id[-1]
                id = (False, 2, adj)

        return id

    def _subgroup(self, id) -> Tuple[
        Group,
        Callable[[GroupElement], GroupElement],
        Callable[[GroupElement], GroupElement]
    ]:
        r"""

        Args:
            id: a subgroup id already normalized by :meth:`_process_subgroup_id`, i.e. ending with an element ``adj``

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)

        Raises:
            NotImplementedError: for the Tetrahedral and the dihedral subgroups
            ValueError: if the id is not recognized

        """

        sg = None
        parent_map = None
        child_map = None

        id, adj = id[:-1], id[-1]

        if id == ('octa',):
            sg = self
            parent_map = build_adjoint_map(self, ~adj)
            child_map = build_adjoint_map(self, adj)
        elif id == ('tetra',):
            raise NotImplementedError()
        elif id == (False, 1):
            sg = package.group.cyclic_group(1)
            parent_map, child_map = build_trivial_subgroup_maps(self)
        elif id == (False, 2):
            # 2-fold rotations around the Z axis
            sg = package.group.cyclic_group(2)
            axis = np.asarray([0., 0., 1.])
            parent_map = cn_to_octa(adj, sg, axis=axis)
            child_map = octa_to_cn(adj, sg, axis=axis)
        elif id == (False, 3):
            # 3-fold rotations around the diagonal of the cube through (1, 1, 1)
            sg = package.group.cyclic_group(3)
            axis = np.asarray([1., 1., 1.]) / np.sqrt(3)
            parent_map = cn_to_octa(adj, sg, axis=axis)
            child_map = octa_to_cn(adj, sg, axis=axis)
        elif id == (False, 4):
            # 4-fold rotations around the Z axis
            sg = package.group.cyclic_group(4)
            axis = np.asarray([0., 0., 1.])
            parent_map = cn_to_octa(adj, sg, axis=axis)
            child_map = octa_to_cn(adj, sg, axis=axis)
        elif id in [(True, 2), (True, 3), (True, 4)]:
            # dihedral subgroups D_N are not supported yet
            raise NotImplementedError()
        else:
            raise ValueError(f'Subgroup id {id} not recognized!')

        return sg, parent_map, child_map

    def _restrict_irrep(self, irrep: Tuple, id) -> Tuple[np.ndarray, List[Tuple]]:
        r"""

        Args:
            irrep: the id of the irrep of this group to restrict (unpacked into :meth:`irrep`)
            id: the normalized subgroup id, ending with the element ``adj``

        Returns:
            a pair containing the change of basis and the list of irreps of the subgroup which appear in the restricted irrep

        """
        sg_id, adj = id[:-1], id[-1]

        irr = self.irrep(*irrep)

        sg, _, _ = self.subgroup(id)

        irreps = []
        change_of_basis = None

        try:
            if sg_id == ('octa', ):
                change_of_basis = irr.change_of_basis
                irreps = irr.irreps
            elif sg_id == (False, 1):
                change_of_basis = np.eye(irr.size)
                # C_1 has a single irrep, the trivial one: take its id from the subgroup instead of hard-coding it
                irreps = list(sg.trivial_representation.irreps) * irr.size
            else:
                # no closed form available here: fall back to the generic numerical restriction below
                raise NotImplementedError()

        except NotImplementedError:
            if sg.order() > 0:
                change_of_basis, irreps = restrict_irrep(irr, sg_id)
            else:
                raise

        # account for the conjugated copy of the subgroup selected by `adj`
        change_of_basis = irr(adj).T @ change_of_basis

        return change_of_basis, irreps

    def _build_representations(self):
        r"""
        Build the irreps for this group

        """

        # Build all the Irreducible Representations

        # add Trivial representation
        self.irrep(0)

        # add other irreducible representations

        # Frequency 1 Wigner D matrix
        self.irrep(1)

        # Frequency 2 Wigner D matrix decomposes as a direct sum of a 2 and a 3 dimensional irrep
        self.irrep(-1)  # 3 dimensional irrep
        self.irrep(2)  # 2 dimensional irrep

        # SO(3)'s freq 3 irrep decomposes in a 1-dimensional irrep and the sum of the two previous 3 dimensional irreps
        self.irrep(3)  # 1 dimensional

        # add all the irreps to the set of representations already built for this group
        self.representations.update(**{irr.name: irr for irr in self.irreps()})

        # build the regular representation

        # N.B.: it represents the LEFT-ACTION of the elements
        self.representations['regular'] = self.regular_representation

    @property
    def trivial_representation(self) -> Representation:
        return self.irrep(0)

    @property
    def standard_representation(self) -> Representation:
        r"""
        Restriction of the standard representation of SO(3) as 3x3 rotation matrices

        """
        name = 'standard'

        if name not in self._representations:
            # the frequency 1 irrep is expressed in a permuted basis: this change of basis maps it back to
            # the standard 3x3 rotation matrices
            change_of_basis = np.array([
                [0, 0, 1],
                [1, 0, 0],
                [0, 1, 0]
            ])

            self._representations[name] = change_basis(
                self.irrep(1),
                change_of_basis=change_of_basis,
                name=name,
                supported_nonlinearities=self.irrep(1).supported_nonlinearities,
            )

        return self._representations[name]

    @property
    def cube_vertices_representation(self) -> Representation:
        # action on the 8 vertices of the cube (or faces of the octahedron)

        sgid = (False, 3)
        return self.quotient_representation(sgid, name='cube_vertices')

    @property
    def cube_faces_representation(self) -> Representation:
        # action on the 6 faces of the cube (or vertices of the octahedron)

        sgid = (False, 4)
        return self.quotient_representation(sgid, name='cube_faces')

    @property
    def cube_edges_representation(self) -> Representation:
        # action on the 12 edges of the cube or octahedron

        sgid = (True, 1)
        return self.quotient_representation(sgid, name='cube_edges')

    def irrep(self, l: int) -> IrreducibleRepresentation:
        r"""
        Build the irrep of :math:`O` identified by the integer :math:`l`.
        For :math:`l = 0, 1`, the irrep is equivalent to the Wigner D matrix of the same frequency :math:`l`.
        For :math:`l=2`, the 5-dimensional Wigner D matrix is decomposed in a 3-dimensional and a 2-dimensional irreps,
        here identified respectively by :math:`l=-1` and :math:`l=2`.
        For :math:`l=3`, the 7-dimensional Wigner D matrix is decomposed in a 1-dimensional irrep and the two previous
        3-dimensional irreps, here identified respectively by :math:`l=3` and :math:`l=1, -1`.

        Irreps are built once and then cached in the group.

        Args:
            l (int): identifier of the irrep, one of ``-1, 0, 1, 2, 3``

        Returns:
            the corresponding irrep

        """

        assert isinstance(l, int)
        assert -1 <= l <= 3

        name = f"irrep_{l}"
        irrep_id = (l,)

        if irrep_id not in self._irreps:

            if l == 0:
                # Trivial representation
                irrep = build_trivial_irrep()
                character = build_trivial_character()
                supported_nonlinearities = ['pointwise', 'norm', 'gated', 'gate']
                self._irreps[irrep_id] = IrreducibleRepresentation(self, irrep_id, name, irrep, 1, 'R',
                                                                   supported_nonlinearities=supported_nonlinearities,
                                                                   character=character,
                                                                   )
            elif l == 1:

                # Irreducible Representation equivalent to the frequency 1 Wigner D matrices
                irrep = _build_irrep(l)
                character = _build_character(l)
                supported_nonlinearities = ['norm', 'gated']
                self._irreps[irrep_id] = IrreducibleRepresentation(self, irrep_id, name, irrep, 3, 'R',
                                                                   supported_nonlinearities=supported_nonlinearities,
                                                                   character=character)
            elif l == -1 or l == 2:

                irrep = _build_octa_irrep(self, l)
                supported_nonlinearities = ['norm', 'gated']
                self._irreps[irrep_id] = IrreducibleRepresentation(self, irrep_id, name, irrep,
                                                                   irrep[self.identity].shape[0], 'R',
                                                                   supported_nonlinearities=supported_nonlinearities)

            elif l == 3:
                irrep = _build_octa_irrep(self, l)
                supported_nonlinearities = ['norm', 'gated', 'concatenated']
                self._irreps[irrep_id] = IrreducibleRepresentation(self, irrep_id, name, irrep,
                                                                   irrep[self.identity].shape[0], 'R',
                                                                   supported_nonlinearities=supported_nonlinearities)

            else:
                raise ValueError()

        return self._irreps[irrep_id]

    # single shared instance returned by `_generator`
    _cached_group_instance = None

    @classmethod
    def _generator(cls) -> 'Octahedral':
        if cls._cached_group_instance is None:
            cls._cached_group_instance = Octahedral()

        return cls._cached_group_instance


def _is_axis_aligned(v: np.ndarray, n: int, verbose: bool = False, ATOL=1e-7, RTOL = 1e-5) -> bool:

    norm = np.linalg.norm(v)
    v = v / norm

    if n == 2:
        # rotation of order 2

        # the rotation axis need to be aligned with one of the axes X, Y, Z or to
        # the bisector of a pair of axes XY, XZ, YZ
        # There are in total 6 + 12 possible vectors

        # remove sign ambiguity
        v = np.abs(v)

        axes = np.eye(3)
        bisectors = np.array([
            [1., 0., 1.],
            [1., 1., 0.],
            [0., 1., 1.],
        ]) / np.sqrt(2)

        ans = (
                # axes aligned
                np.allclose(v, axes[0], atol=ATOL, rtol=RTOL)
                or np.allclose(v, axes[1], atol=ATOL, rtol=RTOL)
                or np.allclose(v, axes[2], atol=ATOL, rtol=RTOL)
                # bisectors aligned
                or np.allclose(v, bisectors[0], atol=ATOL, rtol=RTOL)
                or np.allclose(v, bisectors[1], atol=ATOL, rtol=RTOL)
                or np.allclose(v, bisectors[2], atol=ATOL, rtol=RTOL)
        )

        if not ans and verbose:
            print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.')

        return ans

    elif n == 4:
        # rotation of order 4

        # the rotation axis need to be aligned with one of the axes X, Y, Z
        # There are in total 6 possible vectors

        # remove sign ambiguity
        v = np.abs(v)

        axes = np.eye(3)

        ans = (
               np.allclose(v, axes[0], atol=ATOL, rtol=RTOL)
            or np.allclose(v, axes[1], atol=ATOL, rtol=RTOL)
            or np.allclose(v, axes[2], atol=ATOL, rtol=RTOL)
        )

        if not ans and verbose:
            print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.')

        return ans

    elif n == 3:
        # rotation or order 3

        # the rotation axis need to pass through one of the vertices of the cube
        # There are in total 8 possible vectors

        # remove sign ambiguity
        v = np.abs(v)

        # since the vector is normalized, `v` should now be `(1, 1, 1)^T * 1/sqrt(3)`

        ans = np.allclose(v, 1./np.sqrt(3), atol=ATOL, rtol=RTOL)

        if not ans and verbose:
            print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.')

        return ans

    else:
        raise ValueError('The rotation order must be one of {2, 3, 4}.')


#############################################
# SUBGROUPS MAPS
#############################################

# C_N #####################################

def octa_to_cn(adj: GroupElement, cn: package.group.CyclicGroup, axis: np.ndarray):
    r"""
    Build the map from the Octahedral group to its cyclic subgroup :math:`C_N` of rotations around ``axis``.

    An element ``e`` of the Octahedral group is first conjugated to ``adj @ e @ ~adj``. The result is then compared
    with the rotations around ``axis`` by a multiple of :math:`2\pi / N`.

    Args:
        adj (GroupElement): element of the Octahedral group selecting the conjugated copy of the subgroup
        cn (CyclicGroup): the cyclic group :math:`C_N`, with :math:`N \in \{2, 3, 4\}`
        axis (np.ndarray): unit vector of shape ``(3,)`` aligned with an ``N``-fold axis of the octahedron

    Returns:
        a function mapping an element of the Octahedral group to the corresponding element of ``cn``, or to ``None``
        if the element does not belong to the subgroup
    """
    assert isinstance(adj.group, Octahedral)

    assert axis.shape == (3,)
    assert np.isclose(np.linalg.norm(axis), 1.)

    assert cn.order() in [2, 3, 4]

    assert _is_axis_aligned(axis, cn.order())

    def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
        octa = adj.group
        assert e.group == octa

        e = adj @ e @ (~adj)

        # unit quaternion with the vector part first and the scalar part last
        e = e.to('Q')

        v = e[:3]

        n = np.linalg.norm(v)

        if np.allclose(n, 0.):
            return cn.identity

        direction = v / n
        if np.allclose(direction, axis):
            # the rotation is along `axis`
            s, c = n, e[-1]
        elif np.allclose(direction, -axis):
            # `q` and `-q` encode the same rotation: flip the sign of the quaternion so that its vector part
            # points along `axis` (otherwise these subgroup elements would be wrongly mapped to `None`)
            s, c = n, -e[-1]
        else:
            # the rotation axis is not aligned with `axis`
            return None

        theta = 2 * np.arctan2(s, c)
        try:
            return cn.element(theta, 'radians')
        except ValueError:
            # presumably raised when `theta` is not a multiple of 2pi/N
            return None

    return _map


def cn_to_octa(adj: GroupElement, cn: package.group.CyclicGroup, axis: np.ndarray):
    r"""
    Build the inclusion map of the cyclic subgroup :math:`C_N` of rotations around ``axis`` into the Octahedral group.

    An element of ``cn`` with angle :math:`\theta` is mapped to the rotation by :math:`\theta` around ``axis``,
    conjugated by ``adj`` as ``~adj @ rotation @ adj``.

    Args:
        adj (GroupElement): element of the Octahedral group selecting the conjugated copy of the subgroup
        cn (CyclicGroup): the cyclic group :math:`C_N`, with :math:`N \in \{2, 3, 4\}`
        axis (np.ndarray): unit vector of shape ``(3,)`` aligned with an ``N``-fold axis of the octahedron

    Returns:
        a function mapping an element of ``cn`` to the corresponding element of the Octahedral group
    """
    assert isinstance(adj.group, Octahedral)

    assert axis.shape == (3,)
    assert np.isclose(np.linalg.norm(axis), 1.)

    assert cn.order() in [2, 3, 4]

    assert _is_axis_aligned(axis, cn.order())

    def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
        assert e.group == cn

        half_angle = e.to('radians') / 2.

        # unit quaternion of the rotation by `2 * half_angle` around `axis`, scalar part last
        quaternion = np.empty(4)
        quaternion[-1] = np.cos(half_angle)
        quaternion[:3] = np.sin(half_angle) * axis

        rotation = adj.group.element(quaternion, 'Q')
        return (~adj) @ rotation @ adj

    return _map


#############################################
# Generate irreps
#############################################

from joblib import Memory
from package.group import __cache_path__

# on-disk joblib cache, stored under the package's `__cache_path__`; used below to memoize `_build_octa_irrep`
cache = Memory(__cache_path__, verbose=2)


def _octa_presentation_generators(octa: Octahedral) -> Tuple[GroupElement, GroupElement, GroupElement]:
    r"""
    Build the generators ``(t, k, s)`` with respect to which https://arxiv.org/pdf/1110.6376.pdf expresses the matrix
    coefficients of the irreps of the Octahedral group, starting from the generators and the elements of ``octa``.
    """
    r3 = octa.generators[0]
    r = r3 @ r3 @ r3

    k = octa.elements[0]
    s = octa.generators[1]
    t = ~s @ k @ s @ r

    return t, k, s


@cache.cache(ignore=['octa'])
def _build_octa_irrep(octa: Octahedral, l: int):
    r"""
    Build the matrices of the irrep ``l`` of the Octahedral group ``octa``.

    The matrix coefficients of the generators ``t``, ``k`` and ``s`` are taken from
    https://arxiv.org/pdf/1110.6376.pdf. They are extended to the whole group with
    ``generate_irrep_matrices_from_generators``.
    The result is cached on disk; ``octa`` is excluded from the cache key, so the cache entry depends only on ``l``.

    Args:
        octa (Octahedral): the Octahedral group
        l (int): identifier of the irrep, one of ``-1``, ``2``, ``3``

    Returns:
        the output of ``generate_irrep_matrices_from_generators``, indexed by group element

    Raises:
        ValueError: if ``l`` is not one of ``-1``, ``2``, ``3``
    """

    if l == -1:

        # Representation of `t`
        rho_t = np.array([
            [1., 0., 0.],
            [0., 0., 1.],
            [0., 1., 0.],
        ])

        # Representation of `k`
        rho_k = np.array([
            [1., 0., 0.],
            [0., -1., 0.],
            [0., 0., -1.],
        ])

        # Representation of `s`
        rho_s = np.array([
            [0., 1., 0.],
            [0., 0., 1.],
            [1., 0., 0.],
        ])

        # https://arxiv.org/pdf/1110.6376.pdf defines the irrep `l = 1` (denoted by 3 there) as our
        # `standard_representation`, which is expressed in a different basis than the Wigner D matrix with l=1.
        # Since `l=-1` (their 3') is defined as the tensor product between `l=1` and `l=3` (their 1'),
        # we apply the inverse of the change of basis used in `standard_representation` to ensure that
        # `-1 = 1 \tensor 3` holds for us as well
        change_of_basis = np.array([
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]
        ])

        rho_t = change_of_basis.T @ rho_t @ change_of_basis
        rho_k = change_of_basis.T @ rho_k @ change_of_basis
        rho_s = change_of_basis.T @ rho_s @ change_of_basis

    elif l == 2:

        # Representation of `t`
        rho_t = np.array([
            [0., 1.],
            [1., 0.],
        ])

        # Representation of `k`
        rho_k = np.array([
            [1., 0.],
            [0., 1.],
        ])

        # Representation of `s`
        rho_s = 0.5 * np.array([
            [-1., -np.sqrt(3)],
            [np.sqrt(3), -1.],
        ])

    elif l == 3:

        # Representation of `t`
        rho_t = np.array([[-1.]])

        # Representation of `k`
        rho_k = np.array([[1.]])

        # Representation of `s`
        rho_s = np.array([[1.]])

    else:
        raise ValueError(f'Irrep {l} of the Octahedral group can not be built with `_build_octa_irrep`!')

    # the matrix coefficients above are expressed with respect to a different set of generators than `octa.generators`
    t, k, s = _octa_presentation_generators(octa)

    generators = [
        (t, rho_t),
        (k, rho_k),
        (s, rho_s),
    ]

    return generate_irrep_matrices_from_generators(octa, generators)
