import torch

def mdot(A, B):
    """Return the sum of all entries of the element-wise product of A and B.

    The product is formed with ``torch.mul``, so the operands are combined
    according to its rules; the result is then reduced over every element
    into a single 0-dim tensor.
    """
    elementwise = torch.mul(A, B)
    return elementwise.sum()

def nnz3(x, eps):
    """Count the entries of x that are non-zero after thresholding by eps.

    An entry is counted when it is non-zero and its absolute value is not
    strictly below ``eps``. This is the same number the previous
    implementation produced by first overwriting small entries with zero
    and then counting the non-zeros, but the input tensor is no longer
    modified.

    Args:
        x: Tensor to inspect; left untouched.
        eps: Threshold; entries with ``abs(x) < eps`` are treated as zero.

    Returns:
        The number of surviving non-zero entries as a Python int.
    """
    # Negate the "< eps" test (rather than using ">= eps") so that entries
    # for which the comparison is False, e.g. NaN, are still counted exactly
    # as they were when the small entries were zeroed in place.
    treated_as_zero = torch.abs(x) < eps
    survivors = ~treated_as_zero & (x != 0)
    return int(survivors.sum())

def getreal(x):
    """Return a real-valued version of x.

    Non-complex tensors are returned unchanged (the same object). For a
    complex tensor the real parts are returned, with every entry whose
    imaginary part is ``>= 1`` replaced by zero.

    The input tensor is never modified: ``torch.real`` returns a view that
    shares storage with ``x``, so the real parts are copied before any entry
    is zeroed.

    Args:
        x: Tensor, complex or not.

    Returns:
        ``x`` itself when it is not complex, otherwise a new real tensor.
    """
    if not torch.is_complex(x):
        return x

    # Copy first: writing through the view returned by torch.real would
    # zero the real components of the caller's tensor as well.
    real_parts = torch.real(x).clone()

    # NOTE(review): the threshold is the literal 1 and is compared against
    # the signed imaginary part, so entries with negative imaginary parts
    # are never zeroed -- confirm this is the intended criterion.
    real_parts[torch.imag(x) >= 1] = 0
    return real_parts


