from LieCG import so13
from LieCG.CG_coefficients.CG_lorentz import _gen_rot
import numpy as np
import torch_geometric
import torch

from LorentzMACE import  data, modules

# Run in double precision: the Lorentz matrix below is built with
# dtype=torch.double, and torch.randn / the model pick up this default dtype.
torch.set_default_dtype(torch.float64)


# Transformation matrix from LieCG's private `_gen_rot` helper.
# NOTE(review): the (0, 0.1, 0.1j) argument presumably encodes rotation/boost
# parameters (the imaginary entry suggesting a boost) -- confirm against LieCG.
Ri = _gen_rot((0, 0.1, 0.1j), dtype=torch.double)

# 3 x 4 input positions (presumably 3 particles, each a 4-component vector --
# TODO confirm). A private seeded generator makes the fixture reproducible
# across runs without touching the global RNG state that other tests may use.
positions = torch.randn(3, 4, generator=torch.Generator().manual_seed(0))

# Contract the last axis of `positions` with the first axis of `Ri`:
# out[..., a] = sum_b positions[..., b] * Ri[b, a].
positions_rot = torch.einsum("...b, ba->...a", positions, Ri)


def _make_config(pos):
    """Wrap a positions tensor in a single-signal ``Configuration``.

    Both fixtures go through this helper so the original and transformed
    configurations differ only in their positions.
    """
    return data.utils.Configuration(
        positions=np.array(pos),
        signal=np.array([1]),
        attributes=None,
    )


config = _make_config(positions)

config_rot = _make_config(positions_rot)

def test_lorentz_model():
    """Smoke-test a forward pass of ``LorentzBOTNet`` on a two-graph batch.

    The batch holds one graph built from the module-level ``config`` and one
    built from ``config_rot`` (the same positions transformed by ``Ri``).
    At present the test only checks that model construction and the forward
    pass complete without raising.
    """
    model_config = dict(
        r_max=4,
        num_bessel=8,
        num_polynomial_cutoff=6,
        max_ell=3,
        radial_basis_cls=modules.basis_classes['RadialLorentzianEmbeddingBlock'],
        interaction_cls=modules.interaction_classes['ComplexAgnosticResidualInteractionBlock'],
        interaction_cls_first=modules.interaction_classes['ComplexAgnosticResidualInteractionBlock'],
        num_interactions=3,
        num_elements=1,
        hidden_irreps=so13.Lorentz_Irreps('32x(0,0) + 32x(1,1)'),
        readout_irreps=so13.Lorentz_Irreps('2x(0,0)'),
        use_cutoff=False,
        avg_num_neighbors=8,
        gate=torch.nn.functional.silu,
        scale=0.005,
        device='cpu',
        MLP_irreps=so13.Lorentz_Irreps('16x(0,0)'),
    )
    model = modules.LorentzBOTNet(**model_config)

    # NOTE(review): cutoff_in=0.0 / cutoff_out=1000 is presumably meant to
    # connect every pair of nodes -- confirm against AtomicData.from_config.
    atomic_data = data.AtomicData.from_config(config, cutoff_in=0.0, cutoff_out=1000)
    atomic_data2 = data.AtomicData.from_config(config_rot, cutoff_in=0.0, cutoff_out=1000)

    # With 2 samples, batch_size=2 and drop_last=False, both graphs always land
    # in the single batch; shuffling would only randomise their order and make
    # failures harder to reproduce, so it is disabled.
    # NOTE(review): torch_geometric.data.DataLoader is the deprecated import
    # path in PyG >= 2.0 (torch_geometric.loader.DataLoader replaces it);
    # switch once the project's pinned PyG version is confirmed.
    data_loader = torch_geometric.data.DataLoader(
        dataset=[atomic_data, atomic_data2],
        batch_size=2,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader)).to('cpu')

    output = model(batch)
    # TODO(review): `config_rot` exists to test Lorentz invariance, but nothing
    # compares the two graphs' outputs yet. Add an assertion (e.g. per-graph
    # readouts agree within tolerance) once the structure of `output` returned
    # by LorentzBOTNet is confirmed.




