
from typing import Tuple, Union, Callable

from package.group import IrreducibleRepresentation, GroupElement, Group, Representation, directsum

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components


__all__ = ["HomSpace"]


class HomSpace:
    r"""
    Homogeneous space :math:`G / H`: the quotient of a group ``G`` by the subgroup ``H`` returned by
    ``G.subgroup(sgid)``.

    For an irrep ``rho`` of ``G`` and an irrep ``psi`` of ``H``, :meth:`basis` builds a basis for the occurrences
    of ``rho`` inside the representation of ``G`` induced from ``psi``.
    """

    def __init__(self,
                 G: Group,
                 sgid: Tuple,
                 ):
        r"""
        A homogeneous space is defined as the quotient space between a group ``G`` and a subgroup ``H``, identified
        by the input ``sgid``.

        Args:
            G (Group): the ambient group
            sgid (tuple): identifier of the subgroup ``H``, as accepted by ``G.subgroup``

        """

        super().__init__()

        # Group:
        self.G = G

        # ``G.subgroup`` returns the subgroup together with an inclusion map and a restriction map.
        # ``same_coset`` relies on the restriction map returning ``None`` for elements outside ``H``.
        self.H, self._inclusion, self._restriction = self.G.subgroup(sgid)

        # tuple:
        self.sgid = sgid

    def same_coset(self, g1: GroupElement, g2: GroupElement) -> bool:
        r"""
        Check whether ``g1`` and ``g2`` lie in the same left coset of ``H``.

        ``g1 H == g2 H`` holds iff ``g1^{-1} g2`` belongs to ``H``, i.e. iff the restriction map accepts it.

        Args:
            g1 (GroupElement): an element of ``G``
            g2 (GroupElement): an element of ``G``

        Returns:
            ``True`` if the two elements belong to the same coset

        """
        assert g1.group == self.G
        assert g2.group == self.G

        # ``~`` binds tighter than ``@``: this is ``(~g1) @ g2``
        d = ~g1 @ g2

        return self._restriction(d) is not None

    def basis(self,
              g: GroupElement,
              rho: Union[IrreducibleRepresentation, Tuple],
              psi: Union[IrreducibleRepresentation, Tuple]
              ) -> np.ndarray:
        r"""
        Evaluate at ``g`` the basis of the occurrences of ``rho`` in the representation induced from ``psi``.

        Args:
            g (GroupElement): the element of ``G`` where the basis is evaluated
            rho (IrreducibleRepresentation or tuple): an irrep of ``G`` or its id
            psi (IrreducibleRepresentation or tuple): an irrep of ``H`` or its id

        Returns:
            an array of shape ``self.dimension_basis(rho.id, psi.id)``, i.e.
            ``(rho.size, multiplicity of rho in Ind psi, psi.size)``

        """

        assert g.group == self.G

        if isinstance(rho, tuple):
            rho = self.G.irrep(*rho)

        if isinstance(psi, tuple):
            psi = self.H.irrep(*psi)

        assert isinstance(rho, IrreducibleRepresentation)
        assert isinstance(psi, IrreducibleRepresentation)

        assert rho.group == self.G
        assert psi.group == self.H

        # (rho.size, multiplicity of rho in Ind psi, psi.size)
        # B[:, j, :] is an intertwiner between f(e) \in V_psi and the j-th occurrence of rho in Ind psi
        #
        # B_0(g) = rho(g) @ B[:, :, 0]
        # contains the basis for f \in Ind psi interpreted as a scalar function f: G \to R
        # (as a subrepresentation of the regular repr of G)
        # i.e. it contains a basis for f(g)_0
        #
        # The full tensor B(g) = rho(g) @ B
        # is a basis for f interpreted as a Mackey function f: G \to V_psi
        B = self._dirac_kernel_ft(rho.id, psi.id)

        return np.einsum('oi, ijp->ojp', rho(g), B)

    def _multiplicity_in_restriction(self, rho_H: Representation, psi: IrreducibleRepresentation) -> int:
        r"""Count how many of the irreps listed in ``rho_H.irreps`` are equal to the ``H``-irrep ``psi``."""
        return sum(1 for irrep_id in rho_H.irreps if self.H.irrep(*irrep_id) == psi)

    def _dirac_kernel_ft(self, rho: Tuple, psi: Tuple, eps: float = 1e-9) -> np.ndarray:
        r"""
        Compute the ``g``-independent tensor ``B`` used by :meth:`basis`, expressed in the basis of ``rho``.

        Args:
            rho (tuple): id of an irrep of ``G``
            psi (tuple): id of an irrep of ``H``
            eps (float, optional): threshold above which a sum of squares is treated as non-zero

        Returns:
            an array of shape ``(rho.size, multiplicity, psi.size)``

        """

        rho = self.G.irrep(*rho)
        psi = self.H.irrep(*psi)

        rho_H = rho.restrict(self.sgid)

        m_psi = self._multiplicity_in_restriction(rho_H, psi)

        if m_psi == 0:
            # psi does not occur in the restriction of rho: the basis is empty
            return np.zeros((rho.size, 0, psi.size))

        psi_endom = psi.endomorphism_basis()
        psi_n_endom = psi.sum_of_squares_constituents
        rho_n_endom = rho.sum_of_squares_constituents

        basis = np.zeros((rho.size, m_psi * psi_n_endom, psi.size))

        # pick the arbitrary basis element e_i (i=0) for V_\psi
        i = 0

        # These quantities only depend on psi, so they are the same for every occurrence of psi in rho_H
        w_i = (psi_endom[:, i, :] ** 2).sum(axis=0)
        nonnull_mask = w_i > eps

        assert nonnull_mask.sum() == psi_n_endom

        O_ij = np.einsum(
            'kj,kab->ajb',
            psi_endom[:, i, nonnull_mask],
            psi_endom,
        )

        column_mask = np.zeros(rho.size, dtype=bool)

        # p: row offset of the current irrep inside rho_H; j: next free column of `basis`
        p = 0
        j = 0
        for irrep_id in rho_H.irreps:
            irrep = self.H.irrep(*irrep_id)

            if irrep == psi:
                basis[p:p + irrep.size, j:j + psi_n_endom, :] = O_ij
                column_mask[p:p + irrep.size] = nonnull_mask
                j += psi_n_endom

            p += irrep.size

        if rho_n_endom > 1:
            # rho has a non-trivial endomorphism algebra: the columns built above are linearly dependent
            # through it, so keep only one column per group of mutually dependent columns

            endom_basis = (
                    rho_H.change_of_basis_inv[column_mask, :]
                    @ rho.endomorphism_basis()
                    @ rho_H.change_of_basis[:, column_mask]
            )
            ortho = (endom_basis ** 2).sum(0) > eps

            assert ortho.sum() == column_mask.sum() * rho_n_endom, (ortho, column_mask.sum(), rho_n_endom)

            n, dependencies = connected_components(csgraph=csr_matrix(ortho), directed=False, return_labels=True)

            # check Frobenius' Reciprocity
            assert n * rho_n_endom == m_psi * psi_n_endom, \
                (n, rho_n_endom, m_psi, psi_n_endom, rho, psi)

            mask = np.zeros(ortho.shape[0], dtype=bool)

            for component in range(n):
                columns = np.nonzero(dependencies == component)[0]
                assert len(columns) == rho_n_endom
                # keep the first column of each connected component
                mask[columns[0]] = True

            assert mask.sum() == n

            basis = basis[:, mask, :]

            assert basis.shape[1] == n

        # map the rows from the basis of rho_H back to the basis of rho
        basis = np.einsum('oi,ijp->ojp', rho_H.change_of_basis, basis)

        return basis

    def dimension_basis(self, rho: Tuple, psi: Tuple) -> Tuple[int, int, int]:
        r"""
        Shape of the array returned by :meth:`basis` for the pair ``(rho, psi)``.

        Args:
            rho (tuple): id of an irrep of ``G``
            psi (tuple): id of an irrep of ``H``

        Returns:
            the tuple ``(rho.size, multiplicity, psi.size)``, where ``multiplicity`` is the number of occurrences of
            ``rho`` in the representation induced from ``psi``

        """
        rho = self.G.irrep(*rho)
        psi = self.H.irrep(*psi)

        # Computing this restriction every time can be very expensive.
        # Representation.restrict(id) keeps a cache of the representations, so the restriction needs to be computed only
        # the first time it is called
        rho_H = rho.restrict(self.sgid)

        m_psi = self._multiplicity_in_restriction(rho_H, psi)

        # Frobenius' Reciprocity theorem
        multiplicity = m_psi * psi.sum_of_squares_constituents / rho.sum_of_squares_constituents

        assert np.isclose(multiplicity, round(multiplicity))

        multiplicity = int(round(multiplicity))

        return rho.size, multiplicity, psi.size

    def _unit_test_basis(self):
        r"""
        Sanity-check :meth:`basis` on random samples for every pair of irreps of ``G`` and ``H``.

        The checks are:

        - the output shape matches :meth:`dimension_basis`;
        - equivariance under ``G``;
        - compatibility with the action of ``H``;
        - when ``H`` is trivial, agreement with the regular representation up to a column permutation.

        """

        for rho in self.G.irreps():
            rho_H = rho.restrict(self.sgid)

            for psi in self.H.irreps():

                for _ in range(30):
                    g1 = self.G.sample()
                    g2 = self.G.sample()

                    k_1 = self.basis(g1, rho, psi)
                    k_2 = self.basis(g2, rho, psi)

                    assert k_1.shape == self.dimension_basis(rho.id, psi.id)
                    assert k_2.shape == self.dimension_basis(rho.id, psi.id)

                    g12 = g2 @ (~g1)

                    # B(g2) = rho(g2 g1^{-1}) B(g1)
                    assert np.allclose(
                        k_2,
                        np.einsum('oi, ijp->ojp', rho(g12), k_1)
                    )

                for _ in range(30):
                    h = self.H.sample()
                    g = self.G.sample()

                    B = self.basis(g, rho, psi)
                    assert B.shape == self.dimension_basis(rho.id, psi.id)

                    Bh = np.einsum('ijp,pq->ijq', B, psi(h))
                    # NOTE(review): rho(g).T is used as the inverse of rho(g), which presumes orthogonal matrices
                    hB = np.einsum('oi,ijp->ojp', rho(g) @ rho_H(h) @ rho(g).T, B)

                    assert np.allclose(
                        Bh, hB
                    )

            if self.H.order() == 1:
                # when inducing from the trivial group, one obtains the regular representation of G
                # (up to a permutation of the columns)

                for _ in range(100):
                    g = self.G.sample()

                    B = self.basis(g, rho, self.H.trivial_representation)

                    rho_g = rho(g)[:, :rho.size // rho.sum_of_squares_constituents]

                    # rho_g and B[..., 0] should be equal to each other up to a permutation of the columns
                    comparison = np.einsum('ij,ik->jk', rho_g, B[..., 0])

                    # therefore the comparison matrix needs to be a permutation matrix
                    assert (np.isclose(comparison.sum(axis=0), 1.)).all()
                    assert (np.isclose(comparison.sum(axis=1), 1.)).all()
                    assert (np.isclose(comparison, 0.) | np.isclose(comparison, 1.)).all()



