import torch
import numpy as np
from math import inf, sqrt
from scipy.special import factorial
import itertools

import logging
logger = logging.getLogger(__name__)


#From Lorentz group equivariant network Bogatskiy
#Copyright (C) 2019, The University of Chicago, Brandon Anderson, 
#Alexander Bogatskiy, David Miller, Jan Offermann, and Risi Kondor. 

#This program is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 2 of the License, or
#(at your option) any later version.

#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.

#A copy of the GNU General Public License is provided below.

class CGDict():
    r"""
    A dictionary of Clebsch-Gordan (CG) coefficients to be used in CG operations.

    The CG coefficients

    .. math::
        \langle \ell_1, m_1, \ell_2, m_2 | l, m \rangle

    are used to decompose the tensor product of two
    irreps of maximum weights :math:`\ell_1` and :math:`\ell_2` into a direct
    sum of irreps with :math:`\ell = |\ell_1 -\ell_2|, \ldots, (\ell_1 + \ell_2)`.

    The object has a dict-like interface. Keys are pairs of input irrep labels
    ``((k1, n1), (k2, n2))`` as produced by ``_gen_cg_dict``; each value is
    itself a dict mapping an output irrep label ``(k, n)`` to a 2-D matrix
    (the rank-3 coefficient tensor with its first two axes flattened, and
    transposed when ``transpose`` is true).

    Parameters
    ----------
    maxdim: int
        Maximum weight for which to calculate the Clebsch-Gordan coefficients.
        This refers to the maximum weight for the ``input tensors``, not the
        output tensors. If ``None`` the dictionary is left uninitialized.
    transpose: bool, optional
        Transpose the CG coefficient matrix for each :math:`(\ell_1, \ell_2)`.
        This cannot be modified after instantiation.
    device: `torch.torch.device`, optional
        Device of CG dictionary. Defaults to the CPU.
    dtype: `torch.torch.dtype`, optional
        Data type of CG dictionary.
    """

    def __init__(self, maxdim=None, transpose=True, dtype=torch.get_default_dtype(), device=None):

        self.dtype = dtype
        self.device = torch.device('cpu') if device is None else device
        self._transpose = transpose
        # ``None`` marks the dictionary as uninitialized (see ``__bool__``).
        self._maxdim = None
        self._cg_dict = {}

        if maxdim is not None:
            self.update_maxdim(maxdim)

    @property
    def transpose(self):
        """
        Use "transposed" version of CG coefficients.
        """
        return self._transpose

    @property
    def maxdim(self):
        """
        Maximum weight for CG coefficients (``None`` while uninitialized).
        """
        return self._maxdim

    def update_maxdim(self, new_maxdim):
        """
        Update maxdim to a new (possibly larger) value. If the new_maxdim is
        larger than the current maxdim, new CG coefficients are calculated
        and the cg_dict is updated. Otherwise, do nothing.

        Parameters
        ----------
        new_maxdim: int
            New maximum weight.

        Return
        ------
        self: `CGDict`
            Returns self with a possibly updated self.cg_dict.
        """
        # If self is already initialized, and maxdim is sufficiently large, do nothing
        if self and (self.maxdim >= new_maxdim):
            return self

        # Only entries whose key is not stored yet are generated.
        generated = _gen_cg_dict(new_maxdim, existing_keys=self._cg_dict.keys())

        cg_dict_new = {}
        for key, irreps in generated.items():
            converted = {}
            for irrep, cg_tens in irreps.items():
                # Flatten the two input-irrep axes into one matrix axis.
                cg_mat = cg_tens.reshape(-1, cg_tens.shape[-1])
                if self.transpose:
                    cg_mat = cg_mat.permute(1, 0)
                # Ensure elements of the new CG dict have the right dtype/device.
                converted[irrep] = cg_mat.to(dtype=self.dtype, device=self.device)
            cg_dict_new[key] = converted

        # Now update the CG dict, and also update maxdim
        self._cg_dict.update(cg_dict_new)
        self._maxdim = new_maxdim

        return self

    def to(self, dtype=None, device=None):
        """
        Convert CGDict() to a new device/dtype.

        Parameters
        ----------
        device : `torch.torch.device`, optional
            Device to move the cg_dict to.
        dtype : `torch.torch.dtype`, optional
            Data type to convert the cg_dict to.

        Return
        ------
        self: `CGDict`
            The same object, with every stored matrix converted in place of
            the old one. With both arguments ``None`` nothing changes.
        """
        # Forward only the arguments that were actually supplied.
        target = {}
        if dtype is not None:
            target['dtype'] = dtype
        if device is not None:
            target['device'] = device

        if target:
            # Values are nested dicts {irrep: matrix}; convert the matrices,
            # not the inner dict (which has no ``.to`` method).
            self._cg_dict = {key: {irrep: cg_mat.to(**target)
                                   for irrep, cg_mat in val.items()}
                             for key, val in self._cg_dict.items()}
            if dtype is not None:
                self.dtype = dtype
            if device is not None:
                self.device = device
        return self

    def keys(self):
        """Return the keys view of the underlying coefficient dictionary."""
        return self._cg_dict.keys()

    def values(self):
        """Return the values view of the underlying coefficient dictionary."""
        return self._cg_dict.values()

    def items(self):
        """Return the items view of the underlying coefficient dictionary."""
        return self._cg_dict.items()

    def __getitem__(self, idx):
        """
        Look up the coefficient entry stored under ``idx``.

        Raises
        ------
        ValueError
            If the dictionary has not been initialized with a maxdim.
        KeyError
            If ``idx`` is not a stored key.
        """
        if not self:
            raise ValueError('CGDict() not initialized. Either set maxdim, or use update_maxdim()')
        return self._cg_dict[idx]

    def __bool__(self):
        """
        Check to see if CGDict has been properly initialized, i.e. whether
        maxdim has been set (it is ``None`` initially).
        """
        return self.maxdim is not None


def _gen_cg_dict(maxdim, existing_keys=None):
    '''
    Outputs a dictionary of tables of CG coefficients for the Lorentz group
    up to the given dimension maxdim of the G irrep subcomponents
    (every irrep of Lorentz group is V1 x V2, where V1 and V2 are irreps of G).
    Keys are tuples of labels (irrep T1, irrep T2) for irreps (which are themselves tuples of integers),
    and the values are again dictionaries of the form (irrep T): matrix,
    where the matrix is rectangular and maps irrep T into T1 x T2.
    This matrix is in fact more naturally stored as a torch.tensor of rank 3.
    Elements of an irrep of Lorentz group whose label is (k,n) are stored as vectors of size (k+1)*(n+1)
    which are concatenations of a set of vectors, exactly one for each l going from abs(k-n)/2 to (k+n)/2,
    and the size of the vector corresponding to l is 2*l+1. These sub-vectors belong to irreps of G.
    Therefore the dictionary values are tensors of shape ( (k1+1)*(n1+1), (k2+1)*(n2+1), (k+1)*(n+1) ).
    If we concatenate all such tensors for given (k1,n1,k2,n2), we get an orthogonal
    transformation from sum(T) to T1xT2, which is the CG operation done in cg_product().

    Parameters
    ----------
    maxdim : int
        Every label component k1, n1, k2, n2 runs over ``range(maxdim)``.
    existing_keys : iterable, optional
        Keys ``((k1, n1), (k2, n2))`` that are already available and must be
        skipped. ``None`` (the default) means nothing is skipped.

    Returns
    -------
    dict
        ``{((k1, n1), (k2, n2)): {(k, n): tensor}}`` for every key not listed
        in ``existing_keys``.
    '''
    # Snapshot the keys to skip; the default None means "nothing exists yet"
    # (a bare ``in None`` test would raise TypeError).
    skip = frozenset(existing_keys) if existing_keys is not None else frozenset()

    cg_dict = {}

    # One shared cache of SU(2) coefficient tables for this whole generation run.
    fastcgmat = memoize(clebschSU2mat)

    for k1, n1, k2, n2 in itertools.product(range(maxdim), repeat=4):
        pair = ((k1, n1), (k2, n2))
        if pair in skip:
            continue
        entry = cg_dict.setdefault(pair, {})
        kmin, kmax = abs(k1 - k2), k1 + k2
        nmin, nmax = abs(n1 - n2), n1 + n2
        # Output labels advance in steps of 2 between |a - b| and a + b.
        for k, n in itertools.product(range(kmin, kmax + 1, 2), range(nmin, nmax + 1, 2)):
            entry[(k, n)] = clebschmat((k1, n1), (k2, n2), (k, n), fastcgmat=fastcgmat).clone().detach()

    return cg_dict


def memoize(func):
    """
    Create a cached version of any function for fast repeated use.

    The returned wrapper accepts positional arguments only and stores one
    result per distinct argument tuple, so the arguments must be hashable.
    The cache lives as long as the wrapper and is never evicted.
    """
    results = {}

    def memoized_func(*args):
        try:
            return results[args]
        except KeyError:
            pass
        # First call with these arguments: compute once and remember.
        value = func(*args)
        results[args] = value
        return value

    return memoized_func


def clebschSU2mat(j1, j2, j3):
    """
    Tabulate the SU(2) Clebsch-Gordan coefficients coupling j1 and j2 to j3.

    Returns a numpy array of shape (2*j1+1, 2*j2+1, 2*j3+1) indexed by
    (j1 + m1, j2 + m2, j3 + m1 + m2). The array is all zeros when j3 is not
    one of |j1 - j2|, |j1 - j2| + 1, ..., j1 + j2.
    """
    shape = tuple(int(2 * spin + 1) for spin in (j1, j2, j3))
    mat = np.zeros(shape)

    # Work with doubled spins so half-integer values become integers.
    allowed_twice_j3 = range(int(2 * abs(j1 - j2)), int(2 * (j1 + j2)) + 1, 2)
    if int(2 * j3) not in allowed_twice_j3:
        return mat

    for twice_m1 in range(-int(2 * j1), int(2 * j1) + 1, 2):
        m1 = twice_m1 / 2
        for twice_m2 in range(-int(2 * j2), int(2 * j2) + 1, 2):
            m2 = twice_m2 / 2
            # Skip projections that do not fit into the j3 multiplet.
            if abs(m1 + m2) > j3:
                continue
            mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = clebschSU2((j1, m1), (j2, m2), (j3, m1 + m2))
    return mat


def clebschmat(rep1, rep2, rep, fastcgmat=memoize(clebschSU2mat)):
    """
    Compute the whole rank 3 tensor of CG coefficients over (l1,m1),(l2,m2),(l,m)
    (implemented via fast matrix multiplication).

    rep1, rep2 and rep are the (k, n) labels of the two input irreps and the
    output irrep. ``fastcgmat`` is the callable used to obtain SU(2)
    coefficient tables; the default is a single memoized ``clebschSU2mat``
    created when this function is defined, so its cache is shared by every
    call that does not pass its own.

    Returns a ``torch.complex128`` tensor whose axes are ordered
    (rep1 index, rep2 index, rep index).
    """
    (k1, n1), (k2, n2), (k, n) = rep1, rep2, rep

    def stacked_tables(a, b):
        # SU(2) tables for every total spin in a/2 (x) b/2, joined on the last axis.
        return np.concatenate([fastcgmat(a / 2, b / 2, twice_j / 2)
                               for twice_j in range(abs(a - b), a + b + 1, 2)], axis=-1)

    out_basis = stacked_tables(k, n)
    k_coupling = fastcgmat(k1 / 2, k2 / 2, k / 2)
    n_coupling = fastcgmat(n1 / 2, n2 / 2, n / 2)
    in_basis1 = stacked_tables(k1, n1)
    in_basis2 = stacked_tables(k2, n2)

    # Implicit-mode einsum: repeated labels are summed, leaving axes (c, k, n).
    contracted = np.einsum('abc,dea,ghb,dgk,ehn',
                           out_basis, k_coupling, n_coupling, in_basis1, in_basis2)
    # 'cab' in implicit mode moves the leading (output-irrep) axis to the end.
    H = np.einsum('cab', contracted)
    return torch.tensor(H).type(torch.complex128)


def clebsch(idx1, idx2, idx):
    """
    Calculate a single Clebsch-Gordan coefficient
    for SL(2,C) coupling (k1,n1,j1,m1) and (k2,n2,j2,m2) to give (k,n,j,m).
    We will never use this in the network.

    Parameters
    ----------
    idx1, idx2, idx : tuple
        4-tuples (k, n, j, m) for the two inputs and the output.

    Returns
    -------
    0 when ``m != m1 + m2``; otherwise the sum over the intermediate
    projections (mm1, mm2) of products of five SU(2) coefficients
    computed by ``clebschSU2``.

    Raises
    ------
    ValueError
        If 2*j1, 2*j2 or 2*j is not one of |k - n|, |k - n| + 2, ..., k + n for
        the matching (k, n). The offending indices are printed first.
    """
    fastcg = clebschSU2

    k1, n1, j1, m1 = idx1
    k2, n2, j2, m2 = idx2
    k, n, j, m = idx

    # Each doubled spin must lie in the range allowed by its (k, n) label.
    if int(2 * j1) not in range(abs(k1 - n1), k1 + n1 + 1, 2):
        print(idx1, idx2, idx)
        raise ValueError('Invalid value of l1')
    if int(2 * j2) not in range(abs(k2 - n2), k2 + n2 + 1, 2):
        print(idx1, idx2, idx)
        raise ValueError('Invalid value of l2')
    if int(2 * j) not in range(abs(k - n), k + n + 1, 2):
        print(idx1, idx2, idx)
        raise ValueError('Invalid value of l')
    # Projections must add up, otherwise the coefficient vanishes.
    if m != m1 + m2:
        return 0

    # mm1 / mm2 run over doubled-integer sets (divided by 2): the set
    # intersections keep only values for which every projection passed to
    # fastcg below stays within the range of its spin.
    H = sum(fastcg((k / 2, mm1 + mm2), (n / 2, m - mm1 - mm2), (j, m)) *
            fastcg((k1 / 2, mm1), (k2 / 2, mm2), (k / 2, mm1 + mm2)) *
            fastcg((n1 / 2, m1 - mm1), (n2 / 2, m2 - mm2), (n / 2, m - mm1 - mm2)) *
            fastcg((k1 / 2, mm1), (n1 / 2, m1 - mm1), (j1, m1)) *
            fastcg((k2 / 2, mm2), (n2 / 2, m2 - mm2), (j2, m2))
            for mm1 in (x / 2 for x in set(range(-k1, k1 + 1, 2)).intersection(set(range(int(2 * m1 - n1), int(2 * m1 + n1 + 1), 2))))
            for mm2 in (x / 2 for x in set(range(-k2, k2 + 1, 2)).intersection(
                set(range(int(2 * m2 - n2), int(2 * m2 + n2 + 1), 2))).intersection(
                    set(range(int(2 * m - n - 2 * mm1), int(2 * m + n - 2 * mm1 + 1), 2))).intersection(
                        set(range(int(- k - 2 * mm1), int(k - 2 * mm1 + 1), 2))))
            )
    return H


# clebschSU2
# Taken from http://qutip.org/docs/3.1.0/modules/qutip/utilities.html

# This file is part of QuTiP: Quantum Toolbox in Python.
#
#    Copyright (c) 2011 and later, Paul D. Nation and Robert J. Johansson.
#    All rights reserved.
#
#    Redistribution and use in source and binary forms, with or without
#    modification, are permitted provided that the following conditions are
#    met:
#
#    1. Redistributions of source code must retain the above copyright notice,
#       this list of conditions and the following disclaimer.
#
#    2. Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#    3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
#       of its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
#    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
#    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
#    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
#    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
#    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
#    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################

def clebschSU2(idx1, idx2, idx3):
    """Calculates the Clebsch-Gordon coefficient
    for SU(2) coupling (j1,m1) and (j2,m2) to give (j3,m3).

    Parameters
    ----------
    idx1 : tuple
        (j1, m1): total angular momentum 1 and its z-component.
    idx2 : tuple
        (j2, m2): total angular momentum 2 and its z-component.
    idx3 : tuple
        (j3, m3): total angular momentum 3 and its z-component.

    Returns
    -------
    cg_coeff : float
        Requested Clebsch-Gordan coefficient. The integer 0 is returned
        directly when m3 != m1 + m2.
    """
    (j1, m1), (j2, m2), (j3, m3) = idx1, idx2, idx3

    # Projection selection rule.
    if m3 != m1 + m2:
        return 0

    # Bounds of the summation index v.
    v_first = int(np.max([-j1 + j2 + m3, -j1 + m1, 0]))
    v_last = int(np.min([j2 + j3 + m1, j3 - j1 + j2, j3 + m3]))

    numerator = (2.0 * j3 + 1.0) * factorial(j3 + j1 - j2) * factorial(j3 - j1 + j2) * \
        factorial(j1 + j2 - j3) * factorial(j3 + m3) * factorial(j3 - m3)
    denominator = factorial(j1 + j2 + j3 + 1) * factorial(j1 - m1) * factorial(j1 + m1) * \
        factorial(j2 - m2) * factorial(j2 + m2)
    prefactor = np.sqrt(numerator / denominator)

    series = 0
    for v in range(v_first, v_last + 1):
        sign = (-1.0) ** (v + j2 + m2)
        term = sign / factorial(v) * factorial(j2 + j3 + m1 - v) * factorial(j1 - m1 + v) / \
            factorial(j3 - j1 + j2 - v) / \
            factorial(j3 + m3 - v) / \
            factorial(v + j1 - j2 - m3)
        series += term

    return prefactor * series

    
def cg_product(cg_dict,
               rep1,
               rep2, maxdim=inf, aggregate=False, ignore_check=False):
    """
    Explicit function to calculate the Clebsch-Gordan product.
    See the documentation for CGProduct for more information.

    Parameters
    ----------
    cg_dict : :obj:`CGDict`
        Clebsch-Gordan dictionary. It must use transposed coefficients and
        its ``maxdim`` must be large enough for the inputs.
    rep1 : dict of :obj:`torch.Tensors`
        First :obj:`GVector` in the CG product, keyed by irrep label (k, n).
    rep2 : dict of :obj:`torch.Tensors`
        Second :obj:`GVector` in the CG product, keyed by irrep label (k, n).
    maxdim : :obj:`int`, optional
        Upper cap on the output: only output labels with k < maxdim and
        n < maxdim are produced.
    aggregate : :obj:`bool`, optional
        Apply an "aggregation" operation, or a pointwise convolution
        with a :obj:`GVector` as a filter.
    ignore_check : :obj:`bool`
        Currently unused by this function; kept for interface compatibility.

    Returns
    -------
    dict
        Maps each output irrep label (k, n) to a tensor obtained by
        concatenating, along dim -2, the contributions of every input pair.

    Raises
    ------
    ValueError
        If ``cg_dict.maxdim`` is too small for the requested product.
    AssertionError
        If ``cg_dict`` does not hold transposed coefficients.
    """
    maxk1 = max(key[0] for key in rep1.keys())
    maxn1 = max(key[1] for key in rep1.keys())
    maxk2 = max(key[0] for key in rep2.keys())
    maxn2 = max(key[1] for key in rep2.keys())
    # Largest label component of each input, used for the check and its message.
    max1, max2 = max(maxk1, maxn1), max(maxk2, maxn2)
    maxDim = min(max(maxk1 + maxk2, maxn1 + maxn2) + 1, maxdim)

    if (cg_dict.maxdim < maxDim) or (cg_dict.maxdim < max(max1, max2)):
        raise ValueError('CG Dictionary maxdim ({}) not sufficiently large for (maxdim, L1, L2) = ({} {} {})'.format(cg_dict.maxdim, maxdim, max1, max2))
    assert(cg_dict.transpose), 'This operation uses transposed CG coefficients!'

    new_rep = {}

    for (k1, n1), irrep1 in rep1.items():
        for (k2, n2), irrep2 in rep2.items():
            # Skip labels beyond the cap and parts whose dim -2 is empty.
            if max(k1, n1, k2, n2) > maxDim - 1 or irrep1.shape[-2] == 0 or irrep2.shape[-2] == 0:
                continue
            # cg_mat, aka H, is initially a dictionary {(k,n):rectangular matrix},
            # which when flattened/stacked over keys becomes an orthogonal square matrix
            # we create a sorted list of keys first and then stack the rectangular matrices over keys
            cg_mat_keys = [(k, n)
                           for k in range(abs(k1 - k2), min(maxdim, k1 + k2 + 1), 2)
                           for n in range(abs(n1 - n2), min(maxdim, n1 + n2 + 1), 2)]
            cg_mat = torch.cat([cg_dict[((k1, n1), (k2, n2))][key] for key in cg_mat_keys], -2)
            # Pairwise tensor multiply parts, loop over atom parts accumulating each.
            irrep_prod = complex_kron_product(irrep1, irrep2, aggregate=aggregate)
            # Multiply by the CG matrix, effectively turning the product into stacked irreps. Channels are preserved
            # Have to add a dummy index because matmul acts over the last two dimensions, so the vector dimension on the right needs to be -2
            cg_decomp = torch.squeeze(torch.matmul(cg_mat, torch.unsqueeze(irrep_prod, -1)), -1)
            # Split the result into a list of separate irreps
            split = [(k + 1) * (n + 1) for (k, n) in cg_mat_keys]
            cg_decomp = torch.split(cg_decomp, split, dim=-1)
            # Add the irreps to the dictionary entries, first keeping the channel dimension as a list
            for key, part in zip(cg_mat_keys, cg_decomp):
                new_rep.setdefault(key, []).append(part)

    # at the end concatenate over the channel dimension back into torch tensors
    new_rep = {key: torch.cat(val, dim=-2) for key, val in new_rep.items()}

    # TODO: Rewrite so ignore_check not necessary
    return new_rep


def complex_kron_product(z1, z2, aggregate=False):
    """
    Take two complex tensors z1 and z2, and take their tensor product over
    the last (vector) dimension.

    Parameters
    ----------
    z1 : :class:`torch.Tensor`
        Tensor of shape 2 x batch1 x C x M.
        The first dimension is the complex (real/imaginary) dimension.
    z2 : :class:`torch.Tensor`
        Tensor of shape 2 x batch2 x C x N.
    aggregate: :class:`bool`
        Apply aggregation/point-wise convolutional filter. One input must have
        batch = B x A x A and the other batch = B x A; the product is summed
        over dimension 3 of the intermediate result.

    Returns
    -------
    z : :class:`torch.Tensor`
        Tensor of shape 2 x batch x C x (M * N); with ``aggregate`` the
        summed batch dimension is removed.

    Raises
    ------
    AssertionError
        If an input has fewer than 3 dimensions, or batch / neighborhood /
        channel sizes do not match.
    ValueError
        If ``aggregate`` is set and the batch ranks are not (3, 2) or (2, 3).
    """
    shape1, shape2 = z1.shape, z2.shape
    assert(len(shape1) >= 3), 'Must have batch dimension!'
    assert(len(shape2) >= 3), 'Must have batch dimension!'

    # Batch (and atom) dimensions sit between the complex dim and the last two.
    batch1, batch2 = shape1[1:-2], shape2[1:-2]
    # The last two dims are the channel count and the vector length.
    (chan1, len1), (chan2, len2) = shape1[-2:], shape2[-2:]

    if aggregate:
        ranks = (len(batch1), len(batch2))
        if ranks == (3, 2):
            assert(batch1[0] == batch2[0]), 'Batch sizes must be equal! {} {}'.format(batch1, batch2)
            assert(batch1[2] == batch2[1]), 'Neighborhood sizes must be equal! {} {}'.format(batch1, batch2)
            # Insert a broadcast axis so z2 lines up with z1's extra batch dim.
            z2 = z2.unsqueeze(2)
            batch2 = z2.shape[1:-2]
            batch = batch1
        elif ranks == (2, 3):
            assert(batch2[0] == batch1[0]), 'Batch sizes must be equal! {} {}'.format(batch1, batch2)
            assert(batch2[2] == batch1[1]), 'Neighborhood sizes must be equal! {} {}'.format(batch1, batch2)
            z1 = z1.unsqueeze(2)
            batch1 = z1.shape[1:-2]
            batch = batch2
        else:
            raise ValueError('Batch size error! {} {}'.format(batch1, batch2))
        agg_sum_dim = 3
    else:
        assert(batch1 == batch2), 'Batch sizes must be equal! {} {}'.format(batch1, batch2)
        batch = batch1

    # Treat the channel index like a "batch index".
    assert(chan1 == chan2), 'Number of channels must match! {} {}'.format(chan1, chan2)

    # Singleton axes make plain broadcasting produce the outer product, both
    # over the two complex axes (2 x 2) and the two vector axes (M x N).
    left_shape = (2, 1) + batch1 + (chan1,) + torch.Size([len1, 1])
    right_shape = (1, 2) + batch2 + (chan1,) + torch.Size([1, len2])
    flat_shape = (4,) + batch + (chan1, len1 * len2)

    outer = z1.view(left_shape) * z2.view(right_shape)
    z = outer.contiguous().view(flat_shape)
    if aggregate:
        # Aggregation is sum over aggregation sum dimension defined above
        z = z.sum(agg_sum_dim, keepdim=False)

    # Rows of zrot combine (rr, ri, ir, ii) into real = rr - ii, imag = ri + ir,
    # i.e. an actual multiplication of complex numbers.
    zrot = torch.tensor([[1., 0., 0., -1.], [0., 1., 1., 0.]], dtype=z.dtype, device=z.device)
    return torch.einsum("ab,b...->a...", zrot, z)




def rotate_part(D, z, side='left', autoconvert=True, conjugate=False):
    """
    Apply a D matrix using complex broadcast matrix multiplication.

    D and z are split along dim 0 into (real, imaginary) parts. With
    ``side='left'`` the products are formed as ``z @ D``, with
    ``side='right'`` as ``D @ z``; in both cases the result is stacked as

        real = z_re * D_re + z_im * D_im
        imag = -(z_re * D_im) + z_im * D_re

    NOTE(review): this is the product with the complex conjugate of D's
    entries rather than with D itself -- confirm this is intended.

    ``autoconvert`` moves D to z's device and dtype first; ``conjugate``
    replaces D by ``dagger(D)`` before multiplying.

    Raises ValueError if ``side`` is neither 'left' nor 'right'.
    """
    if autoconvert:
        D = D.to(z.device, z.dtype)
    if conjugate:
        D = dagger(D)
    d_re, d_im = D.unbind(0)
    z_re, z_im = z.unbind(0)

    # Pick the operand order once; the combination below is the same for both.
    if side == 'left':
        def apply(d_part, z_part):
            return torch.matmul(z_part, d_part)
    elif side == 'right':
        def apply(d_part, z_part):
            return torch.matmul(d_part, z_part)
    else:
        raise ValueError('Must choose side: left/right.')

    real = apply(d_re, z_re) + apply(d_im, z_im)
    imag = - apply(d_im, z_re) + apply(d_re, z_im)
    return torch.stack((real, imag), 0)


def rotate_rep(rep, alpha, beta, gamma, side='left', conjugate=False, cg_dict=None):
    """
    Apply a part-wise left/right sided D-matrix to a (matrix) representation.

    For every (key, part) in ``rep`` a Lorentz D-matrix is built for that key
    from the angles (on ``rep``'s device and dtype) and applied with
    ``rotate_part``. The result is wrapped in the same class as ``rep``.
    """
    device, dtype = rep.device, rep.dtype
    rotated = {}
    for key, part in rep.items():
        D = LorentzD(key, alpha, beta, gamma, cg_dict=cg_dict, device=device, dtype=dtype)
        rotated[key] = rotate_part(D, part, side=side, conjugate=conjugate)
    return rep.__class__(rotated)


def create_J(j):
    """
    Build the spin-j ladder and diagonal matrices.

    Parameters
    ----------
    j : int or float
        Spin; integer or half-integer (e.g. 0.5, 1, 1.5).

    Returns
    -------
    Jp, Jm, Jz, Id : numpy.ndarray
        Four (2j+1) x (2j+1) matrices: ``Jp`` carries sqrt((j + m)(j - m + 1))
        on the first super-diagonal, ``Jm`` carries the same values on the
        first sub-diagonal, ``Jz`` is diagonal with entries j, j-1, ..., -j,
        and ``Id`` is the identity.
    """
    mrange = -np.arange(-j, j)
    jp_diag = np.sqrt((j + mrange) * (j - mrange + 1))
    Jp = np.diag(jp_diag, k=1)
    Jm = np.diag(jp_diag, k=-1)
    Jz = np.diag(-np.arange(-j, j + 1))
    # 2j + 1 is a float for half-integer j, but np.eye needs an int size.
    Id = np.eye(int(round(2 * j + 1)))
    return Jp, Jm, Jz, Id


def create_Jy(j):
    """
    Build the (2j+1) x (2j+1) complex matrix ``-(Jp - Jm) / 2i`` for spin j,
    where Jp / Jm hold sqrt((j + m)(j - m + 1)) on the first super- /
    sub-diagonal.
    """
    m_values = -np.arange(-j, j)
    ladder = np.sqrt((j + m_values) * (j - m_values + 1))
    raising = np.diag(ladder, k=1)
    lowering = np.diag(ladder, k=-1)
    return -(raising - lowering) / complex(0, 2)


def create_Jx(j):
    """
    Build the (2j+1) x (2j+1) complex matrix ``(Jp + Jm) / 2`` for spin j,
    where Jp / Jm hold sqrt((j + m)(j - m + 1)) on the first super- /
    sub-diagonal.
    """
    m_values = -np.arange(-j, j)
    ladder = np.sqrt((j + m_values) * (j - m_values + 1))
    raising = np.diag(ladder, k=1)
    lowering = np.diag(ladder, k=-1)
    return (raising + lowering) / complex(2, 0)


def littled(j, beta):
    """
    Return ``exp(i * beta * Jy)`` for spin j as a numpy array, computed from
    the Hermitian eigendecomposition of the matrix given by ``create_Jy``.
    """
    eigenvalues, eigenvectors = np.linalg.eigh(create_Jy(j))
    # Exponentiate in the eigenbasis, then rotate back.
    phases = np.diag(np.exp(1j * beta * eigenvalues))
    return np.matmul(np.matmul(eigenvectors, phases), eigenvectors.conj().T)


def WignerD(j, alpha, beta, gamma, numpy_test=False, dtype=torch.float, device=torch.device('cpu')):
    """
    Build the spin-j D-matrix for the angles (alpha, beta, gamma).

    The matrix is ``littled(j, beta)`` with row m scaled by exp(i*alpha*m)
    and column m' scaled by exp(i*gamma*m'), m and m' running from -j to j.

    With ``numpy_test`` the complex numpy array is returned as is; otherwise
    it is converted with ``complex_from_numpy`` using ``dtype`` and ``device``.
    """
    d = littled(j, beta)

    m_values = np.arange(-j, j + 1)
    # Column vector of m values so the alpha phase varies along rows.
    row_phase = np.exp(1j * alpha * m_values[:, np.newaxis])
    col_phase = np.exp(1j * gamma * m_values)

    # Element-wise (broadcast) scaling, not matrix multiplication.
    D = row_phase * d * col_phase

    if numpy_test:
        return D
    return complex_from_numpy(D, dtype=dtype, device=device)


def LorentzD(key, alpha, beta, gamma, numpy_test=False, dtype=torch.float, device=torch.device('cpu'), cg_dict=None):
    """
    Build the D-matrix of the Lorentz irrep ``key = (k, n)``.

    The tensor product of ``WignerD(k/2, alpha, beta, gamma)`` with the
    complex conjugate of ``WignerD(n/2, -alpha, beta, -gamma)`` is sandwiched
    between the CG matrix stored under ``((k, 0), (0, n)) -> (k, n)`` and its
    transpose, separately for the real and imaginary parts.

    When ``cg_dict`` is None a fresh ``CGDict`` with maxdim ``max(k, n) + 1``
    is generated on every call. Returns a tensor stacked as (real, imag)
    along dim 0.
    """
    k, n = key
    if cg_dict is None:
        cg_dict = CGDict(maxdim=max(k, n) + 1, transpose=True, dtype=dtype, device=device)._cg_dict

    wigner_k = WignerD(k / 2, alpha, beta, gamma, numpy_test=numpy_test, dtype=dtype, device=device)
    wigner_n = conj(WignerD(n / 2, -alpha, beta, -gamma, numpy_test=numpy_test, dtype=dtype, device=device))
    product = complex_tensor_prod(wigner_k, wigner_n)

    cg_mat = cg_dict[((k, 0), (0, n))][(k, n)]
    prod_re, prod_im = product.unbind(0)
    # Change basis of the real and imaginary parts independently.
    D_re = torch.matmul(torch.matmul(cg_mat, prod_re), cg_mat.t())
    D_im = torch.matmul(torch.matmul(cg_mat, prod_im), cg_mat.t())
    return torch.stack((D_re, D_im), 0)

def _gen_rot(angles, device=torch.device('cpu'), dtype=torch.float, cg_dict=None):
    """
    Compute the Lorentz matrix in cartesian coordinates for the given angles.

    The (1, 1) Lorentz D-matrix is built from ``angles`` and conjugated by a
    fixed complex 4 x 4 change-of-basis matrix (stored as real/imaginary
    parts); the real part of the result is returned.
    """
    # Lorentz D-matrix of the (1, 1) irrep, stacked as (real, imag).
    D = LorentzD((1, 1), *angles, device=device, dtype=dtype, cg_dict=cg_dict)

    s = 1 / sqrt(2.)
    basis_re = [[1, 0, 0, 0], [0, s, 0, 0], [0, 0, 0, 1], [0, -s, 0, 0]]
    basis_im = [[0, 0, 0, 0], [0, 0, -s, 0], [0, 0, 0, 0], [0, 0, -s, 0]]
    basis_im_conj = [[0, 0, 0, 0], [0, 0, s, 0], [0, 0, 0, 0], [0, 0, s, 0]]

    cartesian4 = torch.tensor([basis_re, basis_im], device=device, dtype=dtype)
    # Hermitian conjugate: imaginary part negated, matrix axes swapped.
    cartesian4H = torch.tensor([basis_re, basis_im_conj], device=device, dtype=dtype).permute(0, 2, 1)

    # Complex product D * cartesian4, kept as separate real and imaginary parts.
    prod_re = D[0].matmul(cartesian4[0]) - D[1].matmul(cartesian4[1])
    prod_im = D[0].matmul(cartesian4[1]) + D[1].matmul(cartesian4[0])

    # Real part of cartesian4H * (D * cartesian4).
    return cartesian4H[0].matmul(prod_re) - cartesian4H[1].matmul(prod_im)

def dagger(D):
    """
    Hermitian conjugate of a complex matrix stored as (real, imag) along
    dim 0: the imaginary part is negated and the last two axes are swapped.
    """
    signs = torch.tensor([1, -1], dtype=D.dtype, device=D.device).view(2, 1, 1)
    conjugated = D * signs
    return conjugated.permute((0, 2, 1))


def conj(D):
    """
    Complex conjugate of a tensor stored as (real, imag) along dim 0:
    the real part is kept and the imaginary part is negated.
    """
    signs = torch.tensor([1, -1], dtype=D.dtype, device=D.device).view(2, 1, 1)
    return D * signs


def complex_from_numpy(z, dtype=torch.float, device=torch.device('cpu')):
    """
    Take a numpy array and output a torch tensor holding its real and
    imaginary parts stacked along a new leading dimension of size 2,
    converted to ``dtype`` on ``device``.
    """
    parts = [torch.from_numpy(component).to(dtype=dtype, device=device)
             for component in (z.real, z.imag)]
    return torch.stack(parts, 0)


def complex_tensor_prod(d1, d2):
    """
    Kronecker (tensor) product of two complex matrices stored as
    (real, imag) along dim 0.

    d1 has shape (2, R1, C1) and d2 has shape (2, R2, C2); the result has
    shape (2, R1 * R2, C1 * C2). Raises AssertionError if either input is
    not a stack of rank-2 matrices.
    """
    a_re, a_im = d1.unbind(0)
    b_re, b_im = d2.unbind(0)
    s1 = d1.shape[1:]
    s2 = d2.shape[1:]
    assert len(s1) == 2 and len(
        s2) == 2, "Both tensors must be of rank 2 (and complex)!"

    # Interleaved singleton axes turn broadcasting into an outer product
    # with axis order (row1, row2, col1, col2).
    left = (s1[0], 1, s1[1], 1)
    right = (1, s2[0], 1, s2[1])
    a_re, a_im = a_re.view(left), a_im.view(left)
    b_re, b_im = b_re.view(right), b_im.view(right)

    prod_re = a_re * b_re - a_im * b_im
    prod_im = a_re * b_im + a_im * b_re
    return torch.stack((prod_re, prod_im), 0).view(2, s1[0] * s2[0], s1[1] * s2[1])