from LorentzMACE.modules.irreps_tools import tp_out_irreps_with_instructions
from LieCG import so13
from LieCG.CG_coefficients.CG_lorentz import CGDict, cg_product
import torch
import torch.utils.benchmark as benchmark


def test_tp():
    """Check so13.TensorProduct against LieCG's reference ``cg_product``.

    Builds the complex (1,1) x (1,1) -> (2,2) tensor product with unit
    per-sample weights, evaluates the same product with ``cg_product`` and
    asserts the two agree once the tensor-product output is scaled by 1/3.
    """
    irreps_in1 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_in2 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_target = so13.Lorentz_Irreps('1x(2,2)')

    irreps_out, instruction = tp_out_irreps_with_instructions(irreps_in1, irreps_in2, irreps_target)

    tp = so13.TensorProduct(irreps_in1,
                            irreps_in2,
                            irreps_out,
                            instruction,
                            shared_weights=False,
                            internal_weights=False,
                            use_complex=True)

    x = irreps_in1.randn(1, -1, dtype=torch.complex128)
    y = irreps_in2.randn(1, -1, dtype=torch.complex128)
    # Re-layout for cg_product:
    #   - insert a singleton axis before the last one;
    #   - split complex values into a trailing real/imag axis;
    #   - move that axis to the front.
    # NOTE(review): permute(-1, 0, 1, 2) assumes x is 2-D (batch, dim), which
    # randn(1, -1) presumably returns — confirm.
    x_dict = {(1, 1): torch.view_as_real(x.unsqueeze(-2)).permute(-1, 0, 1, 2)}
    y_dict = {(1, 1): torch.view_as_real(y.unsqueeze(-2)).permute(-1, 0, 1, 2)}

    # NOTE(review): 4 presumably bounds the irreps the CG tables cover — confirm.
    cg_dict = CGDict(4)

    out_prod = cg_product(cg_dict, x_dict, y_dict, maxdim=3)[(2, 2)]
    # Undo the re-layout: real/imag axis back to last, recombine to complex,
    # drop the singleton axis.
    out_prod = torch.view_as_complex(out_prod.permute(1, 2, 3, 0).contiguous()).squeeze(-2)

    out_tp = tp(x, y, torch.ones(1, 1, dtype=torch.complex128))

    # The comparison result must be asserted; previously it was discarded and
    # the test could never fail.
    # NOTE(review): 1/3 is the normalization gap between the two
    # implementations as encoded by the original test — confirm its origin.
    assert torch.allclose(out_tp * (1 / 3), out_prod)


test_tp()


def test_tp_real():
    """Check that the real-valued TensorProduct reproduces the complex one.

    The product is bilinear, so for x = a + ib and y = c + id the complex result
    equals TP(a, c) - TP(b, d) + i * (TP(b, c) + TP(a, d)). The real-path module
    (``use_complex=False``) is evaluated on the four real/imag combinations. The
    result is reassembled into a complex tensor and compared with the
    ``use_complex=True`` module. Both use unit per-sample weights.
    """
    irreps_in1 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_in2 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_target = so13.Lorentz_Irreps('1x(2,2)')

    irreps_out, instruction = tp_out_irreps_with_instructions(irreps_in1, irreps_in2, irreps_target)

    tp_complex = so13.TensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    instruction,
                                    shared_weights=False,
                                    internal_weights=False,
                                    use_complex=True)

    tp_real = so13.TensorProduct(irreps_in1,
                                 irreps_in2,
                                 irreps_out,
                                 instruction,
                                 shared_weights=False,
                                 internal_weights=False,
                                 use_complex=False)

    x = irreps_in1.randn(1, -1, dtype=torch.complex128)
    y = irreps_in2.randn(1, -1, dtype=torch.complex128)
    out_tp_c = tp_complex(x, y, torch.ones(1, 1, dtype=torch.complex128))

    # Match the weights' precision to the real/imag parts of the complex128
    # inputs (float64) instead of the default float32.
    w_real = torch.ones(1, 1, dtype=x.real.dtype)
    out_tp_real = tp_real(x.real, y.real, w_real) - tp_real(x.imag, y.imag, w_real)
    out_tp_imag = tp_real(x.imag, y.real, w_real) + tp_real(x.real, y.imag, w_real)

    out_tp_r = torch.view_as_complex(torch.stack((out_tp_real, out_tp_imag), dim=-1))

    # The comparison result must be asserted; previously it was discarded and
    # the test could never fail.
    assert torch.allclose(out_tp_r, out_tp_c)


test_tp_real()

def test_fc_tp():
    """Smoke-test so13.FullyConnectedTensorProduct for (1,1) x (1,1) -> (2,2).

    Feeds one random complex128 sample per input through the module and prints
    the size of the result; nothing is asserted.
    """
    in_a = so13.Lorentz_Irreps('1x(1,1)')
    in_b = so13.Lorentz_Irreps('1x(1,1)')
    target = so13.Lorentz_Irreps('1x(2,2)')

    fc_tp = so13.FullyConnectedTensorProduct(in_a, in_b, target)

    # Samples are drawn in the same order as the inputs are passed (in_a, then in_b).
    lhs, rhs = (irreps.randn(1, -1, dtype=torch.complex128) for irreps in (in_a, in_b))

    print(fc_tp(lhs, rhs).size())


test_fc_tp()

def benchmark_tp(device='cuda', batch_size=100, number=1000):
    """Time so13.TensorProduct against LieCG's cg_product for (1,1) x (1,1) -> (2,2).

    Args:
        device: torch device that holds the module, the inputs and the CG tables.
        batch_size: number of random samples drawn for each input.
        number: how many times ``Timer.timeit`` runs each statement.

    Prints the ``torch.utils.benchmark`` measurement for each implementation.
    """
    irreps_in1 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_in2 = so13.Lorentz_Irreps('1x(1,1)')
    irreps_target = so13.Lorentz_Irreps('1x(2,2)')

    irreps_out, instruction = tp_out_irreps_with_instructions(irreps_in1, irreps_in2, irreps_target)

    # NOTE(review): unlike test_tp, use_complex is not passed here, so the
    # library default applies although the inputs are complex128 — confirm
    # this is intended.
    tp = so13.TensorProduct(irreps_in1,
                            irreps_in2,
                            irreps_out,
                            instruction,
                            shared_weights=False,
                            internal_weights=False).to(device)

    x = irreps_in1.randn(batch_size, -1, dtype=torch.complex128).to(device)
    y = irreps_in2.randn(batch_size, -1, dtype=torch.complex128).to(device)
    # Allocate the weights once, outside the timed statement, so the tp
    # measurement does not include an allocation the cg_product one does not pay.
    weights = torch.ones(batch_size, 1, dtype=torch.complex128, device=device)
    # Same re-layout as test_tp: real/imag split moved to the leading axis.
    x_dict = {(1, 1): torch.view_as_real(x.unsqueeze(-2)).permute(-1, 0, 1, 2)}
    y_dict = {(1, 1): torch.view_as_real(y.unsqueeze(-2)).permute(-1, 0, 1, 2)}

    cg_dict = CGDict(4, device=device)

    t1 = benchmark.Timer(
        stmt="tp(x,y,weights)",
        globals={'tp': tp,
                 'x': x,
                 'y': y,
                 'weights': weights},
        label='tp_so13',
        sub_label='Implemented using einsum')

    t2 = benchmark.Timer(
        stmt="cg_product(cg_dict,x_dict,y_dict,maxdim=3)[(2,2)]",
        globals={'cg_product': cg_product,
                 'cg_dict': cg_dict,
                 'x_dict': x_dict,
                 'y_dict': y_dict},
        label='tp_cg',
        sub_label='Implemented using einsum')

    print('tp_so13:', t1.timeit(number))

    print('tp_cg:', t2.timeit(number))


# The default device is CUDA; skip instead of crashing on CPU-only machines.
if torch.cuda.is_available():
    benchmark_tp()
