import torch
from ._irreps import Irreps
from ._cartesian import cartesian_3j

# Example configuration: two random batched inputs of degree l1 and l2.
batch = 2
l1, l2 = 3, 4
# Coupling degree l1 = 3 with l2 = 4 can produce output degrees
# |l1 - l2| .. l1 + l2, i.e. l = 1 to 7; l3 = 2 is the one used below.
l3 = 2

# NOTE(review): presumably randn(batch, -1) draws features of shape
# (batch, dim of the irrep) — confirm against Irreps.randn.
# The generator is consumed left to right, so in1 is drawn before in2.
in1, in2 = (Irreps(f"{deg}e").randn(batch, -1) for deg in (l1, l2))

def ictp_example(in1, in2, cartesian_3j):
    """Contract two feature tensors against a three-index coupling tensor.

    Computes ``out[..., k] = sum_{i, j} in1[..., i] * in2[..., j] * C[i, j, k]``.

    Args:
        in1: Tensor of shape ``(..., i)``.
        in2: Tensor of shape ``(..., j)``; its leading axes must match or
            broadcast against those of ``in1``.
        cartesian_3j: Coupling tensor ``C`` of shape ``(i, j, k)``. The
            parameter name shadows the module-level ``cartesian_3j`` helper.
            It is kept unchanged so keyword callers keep working.

    Returns:
        Tensor of shape ``(..., k)``.
    """
    # The ellipsis replaces the fixed batch subscript "b". Results for
    # (batch, dim) inputs are identical to the original, and unbatched or
    # multi-axis batched inputs are also accepted.
    einsum_str = "...i,...j,ijk->...k"
    return torch.einsum(einsum_str, in1, in2, cartesian_3j)

# Couple in1 and in2 into degree l3 and display the result as a
# (batch, 3, ..., 3) tensor with l3 trailing axes of size 3.
# The label follows l3 instead of being hard-coded, and reshape (unlike
# view) does not fail if the einsum result is non-contiguous.
print(
    f"l={l3}",
    ictp_example(in1, in2, cartesian_3j(l1, l2, l3)).reshape(batch, *(3,) * l3),
)

