import torch
from UHstruct.distU.distU import *

def UltraE_score(distU, bh, bt, margin):
    """Combine a distance with bias and margin terms into a score.

    The score is ``-(distU ** 2) + bh + bt + margin``: the squared distance
    is negated, then ``bh``, ``bt`` and ``margin`` are added in that order.

    Args:
        distU: distance value (number or tensor).
        bh: additive term (number or tensor).
        bt: additive term (number or tensor).
        margin: additive term (number or tensor).

    Returns:
        The combined score; type follows ordinary arithmetic/broadcasting of
        the inputs.
    """
    squared_distance = distU ** 2
    return -squared_distance + bh + bt + margin

def UltraE_get_all_distU(eh,R,et,Ceh,Cet, config_yaml):
    """Compute the UltraE distance of one triple and of every candidate column.

    Args:
        eh: first embedding, passed to ``simple_dist_UltraE``.
        R: shared argument passed to every ``simple_dist_UltraE`` call.
        et: second embedding, passed to ``simple_dist_UltraE``.
        Ceh: candidate embeddings; column ``i`` (``Ceh[:, i]``) is one candidate.
        Cet: candidate embeddings indexed in lockstep with ``Ceh``.
        config_yaml: nested mapping; ``["datafeature"]["p"]`` and
            ``["datafeature"]["beta"]`` are forwarded to ``simple_dist_UltraE``.

    Returns:
        ``(distU, CdistU)``: ``distU`` is the distance for ``(eh, R, et)`` and
        ``CdistU`` is a ``(k, 1)`` tensor whose row ``i`` holds the distance for
        ``(Ceh[:, i], R, Cet[:, i])``, with ``k = Ceh.shape[1]``.
    """
    features = config_yaml["datafeature"]
    p = features["p"]
    beta = features["beta"]
    num_candidates = Ceh.shape[1]

    distU = simple_dist_UltraE(eh, R, et, p, beta)

    # NOTE(review): allocated with torch's default dtype/device rather than
    # those of the inputs; each candidate distance is copied (and cast) in.
    CdistU = torch.zeros([num_candidates, 1])
    for column in range(num_candidates):
        head_col = Ceh[:, column]
        tail_col = Cet[:, column]
        CdistU[column] = simple_dist_UltraE(head_col, R, tail_col, p, beta)

    return distU, CdistU

def Sdist(A, B, beta, p):
    """Return the curvature-scaled distance between ``A`` and ``B``.

    With ``t = qdot(A, B, p) / beta`` this returns
    ``sqrt(|beta|) * acos(|t|)`` when ``|t| < 1`` and
    ``sqrt(|beta|) * acosh(|t|)`` otherwise.

    Args:
        A: first point, forwarded to ``qdot``.
        B: second point, forwarded to ``qdot``.
        beta: scalar (Python number or tensor); only ``|beta|`` enters the
            square root, so a negative value is accepted.
        p: forwarded to ``qdot``.

    Returns:
        A tensor holding the distance.

    Note:
        The branch is chosen with a Python ``if`` on ``|t| < 1``, so
        ``qdot(A, B, p) / beta`` must contain exactly one element (otherwise
        torch raises an "ambiguous boolean" RuntimeError).
    """
    # torch.as_tensor reuses an existing tensor (keeping its autograd graph,
    # e.g. for a learnable beta) instead of copying/detaching it as
    # torch.tensor(beta) does, which also emits a UserWarning. For plain
    # Python numbers both produce the same tensor.
    beta = torch.as_tensor(beta)
    abs_temp = abs(qdot(A, B, p) / beta)
    scale = torch.sqrt(abs(beta))
    if abs_temp < 1:
        return scale * torch.acos(abs_temp)
    # |t| >= 1: acosh is defined on [1, inf).
    return scale * torch.acosh(abs_temp)