import torch


def _stack_pair(first, second):
    """Pack two graph tensors into one model input with a leading batch of 1.

    For inputs of shape ``(f, n, n)`` the result has shape ``(1, 2, f, n, n)``,
    i.e. the ``[batch, k, dim_features, n, n]`` layout with ``k = 2``.  This is
    the same tensor as ``torch.stack([first, second]).view(-1, 2, f, n, n)``.
    """
    return torch.stack([first, second]).unsqueeze(0)


def _random_graph_and_permutation(num_features, dim, device):
    """Draw a random graph tensor, a permutation matrix and the relabelled graph.

    Returns ``(A, P, B)``.  ``A`` has shape ``(num_features, dim, dim)`` and is
    standard-normal.  ``P`` is a ``(dim, dim)`` permutation matrix.
    ``B = P.T @ A @ P`` applies the same node relabelling to every feature
    channel.  Sampling uses the CPU generator and the result is then moved to
    ``device``, so seeded runs draw the same numbers regardless of ``device``.
    """
    A = torch.randn(num_features, dim, dim).to(device)
    perm = torch.randperm(dim)
    P = torch.eye(dim)[:, perm].to(device)
    return A, P, P.T @ A @ P


def check_valid(model, device="cuda", *, test1=True, test2=True, test3=True):
    """Print and return permutation-equivariance errors of ``model.model``.

    ``model.model`` is called as ``model.model(x, None)``.  ``x`` has the layout
    ``[batch, k, dim_features, n, n]`` with ``k = 2``; here ``batch = 1`` and
    ``dim_features = 20``.  Each test draws a random graph ``A``, a permutation
    matrix ``P`` and the relabelled copy ``B = P.T @ A @ P``, then compares:

    1. ``f(A, B)`` against ``f(A, A) @ P`` (second graph permuted, n = 5);
    2. ``f(B, A)`` against ``P.T @ f(A, A)`` (first graph permuted, n = 20);
    3. ``f(B, A)`` against ``f(A, B).transpose(-1, -2)`` (graphs swapped, n = 5).

    The Frobenius norm of each difference is printed; an equivariant model
    gives values near 0.

    Args:
        model: Wrapper exposing the network as ``model.model``.  It is moved to
            ``device`` via ``model.to(device)`` (in place for ``nn.Module``).
            It is put in eval mode for the checks, and its previous
            ``training`` flag is restored afterwards.
        device: Device to run on.  The default ``"cuda"`` matches the original
            behaviour.
        test1, test2, test3: Enable or disable the individual checks.

    Returns:
        Dict mapping ``"test1"``, ``"test2"`` and ``"test3"`` (enabled tests
        only) to the float norm of the corresponding difference.
    """
    # Size of the dim_features slot in the input (previously the local `batch`).
    # Hoisted so that every test can run even when test1 is disabled.
    num_features = 20
    model = model.to(device)
    # Stochastic layers (e.g. dropout), if the model has any, would make the
    # compared outputs differ even for an equivariant network, so use eval mode.
    was_training = model.training
    model.eval()

    def run(first, second):
        return model.model(_stack_pair(first, second), None)

    results = {}

    def report(name, label, diff):
        print(f"[{name}: {label}] The norm of the full model difference is:{diff}")
        results[name] = diff.item()

    try:
        # Pure diagnostic: no need to build autograd graphs.
        with torch.no_grad():
            # Test 1 - check equivariance to permuting the second graph.
            if test1:
                A, P, B = _random_graph_and_permutation(num_features, 5, device)
                out1 = run(A, A) @ P
                out2 = run(A, B)
                report("test1", "permute second graph", torch.norm(out2 - out1))
            # Test 2 - check equivariance to permuting the first graph.
            if test2:
                A, P, B = _random_graph_and_permutation(num_features, 20, device)
                out1 = P.T @ run(A, A)
                out2 = run(B, A)
                report("test2", "permute first graph", torch.norm(out2 - out1))
            # Test 3 - check that swapping the two graphs transposes the output.
            if test3:
                A, _, B = _random_graph_and_permutation(num_features, 5, device)
                out1 = run(A, B)
                out2 = run(B, A)
                report(
                    "test3",
                    "swap graphs",
                    torch.norm(out2 - out1.transpose(-1, -2)),
                )
    finally:
        model.train(was_training)
    return results

# Manual smoke test, disabled by a constant guard; change False to True to run it.
# NOTE(review): EasyDict and LightningModel are not defined in the visible code;
# presumably they are imported elsewhere in the file -- confirm before enabling.
if False:
    config = EasyDict(lr=1e-4, wd=0.0)
    model = LightningModel(config=config)
    model = model.to('cuda')
    check_valid(model)
