from math import sqrt
from typing import Optional
from LorentzMACE.modules.utils import minkowski_norm, vectors_to_rep
from LieCG.CG_coefficients.CG_lorentz import cg_product
import torch

class SphericalHarmonics(torch.nn.Module):
    """Build Lorentz-group (l, l) "spherical harmonic" components of 4-vectors.

    The input vectors are first mapped by ``vectors_to_rep`` to their (1, 1)
    representation. Each higher (l, l) component, for l = 2 .. ``max_ell``, is
    the (l, l) output of ``cg_product`` applied to the previous (l-1, l-1)
    component and the (1, 1) component, rescaled by ``sqrt(2l / (l + 1))``.
    All components are concatenated along the last dimension, preceded by a
    constant channel of ones.

    Args:
        cg_dict: Clebsch-Gordan coefficient table, forwarded unchanged to
            ``cg_product`` (its structure is defined by LieCG).
        max_ell: Highest l of the (l, l) components to generate.
    """

    def __init__(
        self,
        cg_dict,
        max_ell: int,
    ) -> None:
        super().__init__()
        self.max_ell = max_ell
        self.cg_dict = cg_dict

    @staticmethod
    def _to_real_layout(rep: torch.Tensor) -> torch.Tensor:
        """Convert a complex 2-D ``[batch, dim]`` tensor to the real ``[2, batch, 1, dim]`` layout passed to ``cg_product``."""
        return torch.view_as_real(rep.unsqueeze(-2)).permute(-1, 0, 1, 2)

    @staticmethod
    def _to_complex_layout(rep: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_to_real_layout``: real ``[2, batch, 1, dim]`` -> complex ``[batch, dim]``."""
        # .contiguous() is required because view_as_complex needs a stride-1 last dim.
        return torch.view_as_complex(rep.permute(1, 2, 3, 0).contiguous()).squeeze(-2)

    def forward(self,
        vector: torch.Tensor, #[batch,4]
    ) -> torch.Tensor:
        """Return the concatenated (l, l) components for l = 0 .. max_ell.

        Args:
            vector: Input 4-vectors, presumably shaped ``[batch, 4]`` — TODO
                confirm against ``vectors_to_rep``.

        Returns:
            Complex tensor: ``[ones, rep_(1,1), rep_(2,2), ..., rep_(max_ell,max_ell)]``
            concatenated along the last dimension.

        NOTE(review): the (1, 1) term is always included, even when
        ``max_ell < 1``; confirm this is intended.
        """
        # vector_rep must be a complex 2-D tensor: view_as_real requires complex
        # input, and the 4-axis permute in _to_real_layout fixes the rank.
        vector_rep = vectors_to_rep(vector)

        # Real layout of the (1, 1) component, built once. It serves as the fixed
        # right-hand factor of every CG product and as the recursion seed.
        vector_rep_real = self._to_real_layout(vector_rep)
        vector_rep_dict = {(1, 1): vector_rep_real}

        sh = [
            torch.ones(vector_rep.shape[:-1] + (1,), device=vector_rep.device, dtype=vector_rep.dtype),
            vector_rep,
        ]

        prev = vector_rep_real  # (ell-1, ell-1) component in real layout
        for ell in range(2, self.max_ell + 1):
            # TODO: recode with a dedicated tensor product.
            ylk = cg_product(
                self.cg_dict,
                {(ell - 1, ell - 1): prev},
                vector_rep_dict,
                maxdim=ell + 1,
            )[(ell, ell)]
            # Normalization. Out-of-place so a tensor returned by cg_product is
            # never mutated, which keeps autograd safe.
            ylk = ylk * sqrt(2 * ell / (ell + 1))
            prev = ylk
            sh.append(self._to_complex_layout(ylk))
        return torch.cat(sh, dim=-1)