import torch

def orth(A):
    """Return the orthogonal (polar) factor of ``A``.

    With the reduced SVD ``A = U @ diag(S) @ Vh``, the result is ``U @ Vh``.
    The columns are orthonormal for tall/square inputs; the rows are
    orthonormal for wide inputs.

    Args:
        A: Tensor of shape ``(..., m, n)``. Leading dimensions are treated as
            a batch.

    Returns:
        Tensor of shape ``(..., m, n)``.
    """
    # torch.svd is deprecated in favour of torch.linalg.svd. The latter
    # returns Vh (the conjugate transpose of V) directly, so no transpose is
    # needed. The previous `V.T` reversed *all* dimensions, which broke
    # batched inputs, and it omitted the conjugate needed for complex inputs.
    U, _, Vh = torch.linalg.svd(A, full_matrices=False)
    return U @ Vh

def exc_operator(A, B1, B2):
    """Apply the exchange operator (principal pivot transform) to ``A``.

    Partition ``A`` by the index sets ``B1`` and ``B2``, writing
    ``A11 = A[B1, B1]``, ``A12 = A[B1, B2]``, ``A21 = A[B2, B1]`` and
    ``A22 = A[B2, B2]``. The result is the block matrix::

        [[ inv(A11),        -inv(A11) @ A12             ],
         [ A21 @ inv(A11),   A22 - A21 @ inv(A11) @ A12 ]]

    The rows and columns of the result are ordered B1 block first, then B2
    block. They are NOT in ``A``'s original index order.

    Args:
        A: 2-D tensor.
        B1: Indices (1-D tensor or sequence of ints) of the pivot block.
            ``A[B1, B1]`` must be invertible.
        B2: Indices (1-D tensor or sequence of ints) of the remaining block.

    Returns:
        Tensor of shape ``(len(B1) + len(B2), len(B1) + len(B2))``.

    Raises:
        Whatever ``torch.linalg.inv`` raises if ``A[B1, B1]`` is singular.
    """
    # Accept plain Python index sequences. Tensors are passed through
    # unchanged, so existing callers see identical behaviour.
    if not torch.is_tensor(B1):
        B1 = torch.as_tensor(B1, dtype=torch.long, device=A.device)
    if not torch.is_tensor(B2):
        B2 = torch.as_tensor(B2, dtype=torch.long, device=A.device)

    rows1 = B1[:, None]
    rows2 = B2[:, None]
    a11 = A[rows1, B1]
    a12 = A[rows1, B2]
    a21 = A[rows2, B1]
    a22 = A[rows2, B2]

    # Invert the pivot block once. The original recomputed it four times.
    # inv(A11) @ A12 is shared by the top-right and bottom-right blocks.
    inv11 = torch.linalg.inv(a11)
    inv11_a12 = torch.mm(inv11, a12)

    top = torch.cat((inv11, -inv11_a12), dim=1)
    bottom = torch.cat((torch.mm(a21, inv11), a22 - torch.mm(a21, inv11_a12)), dim=1)
    return torch.cat((top, bottom), dim=0)
