import torch
from torch import nn, Tensor
import e3nn
from e3nn import o3

from typing import *


def get_vectors_at_l(signal: Tensor, irreps: "o3.Irreps", l: int):
    '''
    Return the components of `signal` that belong to irreps of degree `l`.

    The last dimension of `signal` is expected to follow the layout of `irreps`: every entry
    of irreps.ls occupies 2*l_i + 1 consecutive components, in the order the entries appear
    (the same layout rotate_signal uses to build its block-diagonal matrix). Irreps do not
    need to be sorted by degree. Leading (batch) dimensions are preserved; if no irrep has
    degree `l`, the result has size 0 along the last dimension.
    '''
    # One degree label per component, following the actual order of irreps.ls rather than a
    # sorted grouping (which would mislabel columns for unsorted irreps such as 0e + 1o + 0e).
    component_ls = torch.tensor(
        [l_i for l_i in irreps.ls for _ in range(2 * l_i + 1)],
        dtype=torch.long,
        device=signal.device,
    )
    return signal[..., component_ls == l]

def get_random_wigner_D(lmax: int):
    '''
    Sample a random rotation matrix with o3.rand_matrix() and return its Wigner-D matrices.

    Returns a dict mapping each degree l in 0..lmax (inclusive) to the Wigner-D matrix of the
    sampled rotation, as produced by get_wigner_D_from_rot_matrix.
    '''
    # Delegate instead of repeating the angles -> Wigner-D loop.
    return get_wigner_D_from_rot_matrix(lmax, o3.rand_matrix())

def get_wigner_D_from_rot_matrix(lmax: int, rot_matrix: Tensor):
    '''
    Return the Wigner-D matrices of the rotation `rot_matrix`.

    The matrix is converted to angles with o3.matrix_to_angles, and the result is the dict
    {l: o3.wigner_D(l, alpha, beta, gamma)} for every l in 0..lmax (inclusive).
    '''
    alpha, beta, gamma = o3.matrix_to_angles(rot_matrix)
    # Reuse the shared angles -> Wigner-D builder rather than duplicating its loop.
    return get_wigner_D_from_alpha_beta_gamma(lmax, alpha, beta, gamma)

def get_wigner_D_from_alpha_beta_gamma(lmax: int, alpha: Tensor, beta: Tensor, gamma: Tensor):
    '''
    Build the Wigner-D matrices for the rotation described by the angles (alpha, beta, gamma),
    in the angle convention of e3nn's o3.wigner_D.

    Returns a dict keyed by degree: every l in 0..lmax (inclusive) maps to
    o3.wigner_D(l, alpha, beta, gamma).
    '''
    return {degree: o3.wigner_D(degree, alpha, beta, gamma) for degree in range(lmax + 1)}

def rotate_signal(signal: Tensor, irreps: "o3.Irreps", wigner: Dict):
    '''
    Apply a rotation, given by its Wigner-D matrices, to an irreps-structured signal.

    A block-diagonal matrix W is assembled with one block wigner[l] per entry of irreps.ls, in
    order. Each row of `signal` is treated as one feature vector v, and the result holds W v for
    every row, i.e. signal @ W^T.

    wigner must contain wigner-D matrices for all l's in irreps, otherwise a KeyError will be thrown.

    The block matrix and the signal are promoted to a common dtype (torch.promote_types) and W is
    moved to the signal's device. Mixed inputs such as a float64 signal with float32 Wigner-D
    matrices therefore work instead of failing inside matmul.
    '''
    rot_mat = torch.block_diag(*[wigner[l] for l in irreps.ls])
    dtype = torch.promote_types(signal.dtype, rot_mat.dtype)
    rot_mat = rot_mat.to(device=signal.device, dtype=dtype)
    # Rows are vectors: (W v)^T = v^T W^T, so right-multiply by W^T.
    return torch.matmul(signal.to(dtype), rot_mat.T)

def is_equivariant(function: nn.Module,
                   irreps_in: o3.Irreps,
                   irreps_out: Optional[o3.Irreps] = None,
                   signal_in: Optional[Tensor] = None,
                   rtol: float = 1e-05,  # rtol for torch.allclose()
                   atol: float = 1e-08,  # atol for torch.allclose()
                   device: str = 'cpu'):
    '''
    Numerically check that `function` commutes with a random rotation.

    One random rotation is sampled (get_random_wigner_D). The two paths below are then compared
    with torch.allclose:
      - rotate(function(x)): rotated through the Wigner-D matrices of irreps_out;
      - function(rotate(x)): input rotated through those of irreps_in.
    Rotations are applied via o3.wigner_D of the rotation's angles only, keyed by l, so parity
    (inversion) is not exercised by this check.

    Args:
        function: module to test; called on a signal laid out according to irreps_in.
        irreps_in: irreps of the input signal.
        irreps_out: irreps of the output; if None, function.irreps_out is used.
        signal_in: input signal; if None, a single random sample irreps_in.randn(1, -1) is used.
        rtol, atol: tolerances forwarded to torch.allclose().
        device: device on which `function` is evaluated. Inputs are moved there, and outputs
            are detached and brought back to CPU for the comparison.

    Returns:
        (is_equiv, mean_diff): the torch.allclose() result, and the mean absolute difference
        (a 0-dim tensor) between the two paths.

    rtol and atol may have to be relaxed to account for floating-point error in
    computation-intensive functions. The defaults are those of torch.allclose(), which tend
    to be too strict for this check.
    '''
    # if irreps_out is not provided then it is assumed that function has an irreps_out attribute,
    # and that attribute is used
    if irreps_out is None:
        irreps_out = function.irreps_out

    # if signal_in is not provided, sample random signal with batch_size of 1
    if signal_in is None:
        signal_in = irreps_in.randn(1, -1)
    # The Wigner-D matrices are created without a device argument and all rotations/comparisons
    # happen on CPU, so keep the reference input there too; it is moved to `device` only for
    # the function calls. (A non-CPU signal_in would otherwise fail inside rotate_signal.)
    signal_in = signal_in.cpu()

    lmax = max(max(irreps_in.ls), max(irreps_out.ls))
    wigner = get_random_wigner_D(lmax)

    # Path 1: evaluate, then rotate the output.
    signal_out = function(signal_in.to(device)).detach().cpu()
    signal_out_rotated = rotate_signal(signal_out, irreps_out, wigner)

    # Path 2: rotate the input, then evaluate.
    signal_in_rotated = rotate_signal(signal_in, irreps_in, wigner)
    signal_rotated_out = function(signal_in_rotated.to(device)).detach().cpu()

    is_equiv = torch.allclose(signal_out_rotated, signal_rotated_out, rtol=rtol, atol=atol)
    mean_diff = torch.mean(torch.abs(signal_out_rotated - signal_rotated_out))

    return is_equiv, mean_diff




