import torch
import math

def UltraE_test_hits_rank(triple, vec_entity, vec_relation, vec_bias, config_yaml):
    """Rank every entity as a candidate tail for each triple; report Hits@k and MRR.

    Args:
        triple: integer tensor whose columns are used as ``[:, 0]`` head entity,
            ``[:, 1]`` tail entity (the ranking target) and ``[:, 2]`` relation.
        vec_entity: entity embedding matrix, one row per entity. It is passed
            through ``dproj`` before scoring.
        vec_relation: relation tensor indexed as ``vec_relation[r, :, :]``
            (one matrix per relation).
        vec_bias: per-entity bias. It is presumably 1-D with one entry per
            entity (TODO confirm); the broadcasting below needs a 1-D bias.
        config_yaml: mapping whose ``"datafeature"`` section provides ``p``,
            ``beta``, ``margin`` and the Hits cut-offs ``hits1``/``hits2``/``hits3``.

    Returns:
        ``(hits, MRR)``:
            - ``hits`` is a length-3 float tensor with the fraction of triples
              whose true tail ranks within each cut-off.
            - ``MRR`` is the mean reciprocal (1-based) rank.

    Note:
        The query's own head entity is pushed to the bottom of every ranking
        (score ``-1e10``). That also applies when head == tail.
    """
    features = config_yaml["datafeature"]
    p = features["p"]
    beta = features["beta"]
    margin = features["margin"]
    hits1 = features["hits1"]
    hits2 = features["hits2"]
    hits3 = features["hits3"]

    vec_entity = dproj(p, beta, vec_entity)

    n = triple.size(0)
    entity_num = vec_entity.size(0)

    heads = triple[:, 0]
    tails = triple[:, 1]
    eh = vec_entity[heads, :]
    bh = vec_bias[heads]
    R = vec_relation[triple[:, 2], :, :]

    # Every entity is a candidate tail for every query. Broadcast views hold
    # the same values as gathering with an (n, entity_num) index matrix, but
    # allocate no (n, entity_num, dim) copy.
    et = vec_entity.unsqueeze(0).expand(n, -1, -1)
    bt = vec_bias[:entity_num].unsqueeze(0)

    Udist = simple_dist_UltraE(eh, R, et, p, beta)
    Score = UltraE_score(Udist, bh.unsqueeze(1), bt, margin)

    # Exclude the query's own head entity from its candidate list.
    rows = torch.arange(n, device=Score.device)
    Score[rows, heads.to(Score.device)] = -1e10

    _, t_index = torch.sort(Score, descending=True, dim=1)
    # Each row of t_index is a permutation, so the true tail matches exactly
    # once; its column index + 1 is the 1-based rank.
    rank_test = torch.where(t_index == tails.to(t_index.device).unsqueeze(1))[1] + 1

    hits = torch.zeros(3)
    for i, cutoff in enumerate((hits1, hits2, hits3)):
        hits[i] = torch.sum(rank_test <= cutoff) / n
    MRR = torch.mean(1.0 / rank_test)

    return hits, MRR

def dproj(p, beta, X):
    """Project each row of ``X`` onto the set where ``||z[:p]||^2 - ||z[p:]||^2 == -|beta|``.

    The first ``p`` coordinates of each row are kept unchanged. The remaining
    coordinates are rescaled along their own direction so that their squared
    Euclidean norm equals ``|beta| + ||z[:p]||^2``.

    The result is undefined (division by zero) when ``beta == 0`` or when the
    coordinates from ``p`` onward of a row are all zero.

    Args:
        p: number of leading coordinates that are left untouched.
        beta: non-zero scalar curvature-like parameter.
        X: 2-D tensor, one point per row.

    Returns:
        Tensor with the same shape as ``X``.
    """
    time_part = X[:, :p]
    space_part = X[:, p:]

    # Step 1: rescale the trailing block to Euclidean norm |beta|
    # (its direction is flipped when beta < 0).
    space_len = torch.linalg.norm(space_part, ord=2, dim=-1, keepdim=True)
    space_on_sphere = beta * space_part / space_len

    # Step 2: stretch that block by sqrt(|beta| + ||time||^2) / beta so the
    # indefinite squared norm becomes -|beta|.
    time_len = torch.linalg.norm(time_part, ord=2, dim=-1, keepdim=True)
    stretch = torch.sqrt(abs(beta) + time_len ** 2) / beta
    space_final = stretch * space_on_sphere

    return torch.cat([time_part, space_final], dim=1)

def simple_dist_UltraE(eh, R, et, p, beta):
    """Apply each query's relation matrix to its head and take the smaller UltraE distance.

    Args:
        eh: head embeddings, one row per query.
        R: one matrix per query, indexed as ``R[i, :, :]``.
        et: candidate tail embeddings. Presumably (n, num_candidates, dim);
            TODO confirm.
        p: number of leading coordinates with ``+`` sign in the product.
        beta: scalar parameter forwarded to the distance functions.

    Returns:
        Element-wise minimum of the two distances returned by ``dist_UltraE``.
    """
    # Batched matrix-vector product R[i] @ eh[i], as broadcast-multiply + sum.
    transformed_head = (R * eh.unsqueeze(1)).sum(dim=2)

    d13, d14 = dist_UltraE(transformed_head, et, p, beta)
    # Stack the two candidate distances on a new last axis and keep the smaller.
    return torch.stack((d13, d14), dim=2).min(dim=2).values

def dist_UltraE(Reh, et, p, beta):
    """Return the pair ``(dist13(...), dist14(...))`` for the same arguments.

    ``dist13`` is evaluated first, then ``dist14``.
    """
    return dist13(Reh, et, p, beta), dist14(Reh, et, p, beta)

def dist13(x, y, p, beta):
    """Two-leg distance from ``y`` to ``x`` through ``m = rhob_a(y, x, p)``.

    Returns ``Sdist(y, m) + Sdist(m, x)``.

    Args:
        x: 2-D tensor, one row per query. It is broadcast across the candidate
            axis of ``y``.
        y: 3-D tensor whose axis 1 enumerates candidates. Presumably
            (n, num_candidates, dim); TODO confirm.
        p: number of leading coordinates with ``+`` sign in the product.
        beta: scalar parameter forwarded to ``Sdist``.
    """
    # (n, d) -> (n, y.shape[1], d). expand() yields a view holding the same
    # elements as the former unsqueeze(2).repeat(...).transpose(1, 2) copy,
    # without allocating it.
    x = x.unsqueeze(1).expand(-1, y.shape[1], -1)
    # The intermediate point is shared by both legs; compute it once instead of twice.
    mid = rhob_a(y, x, p)
    return Sdist(y, mid, beta, p) + Sdist(mid, x, beta, p)


def dist14(x, y, p, beta):
    """Two-leg distance from ``x`` to ``y`` through ``m = rhob_a(x, y, p)``.

    Returns ``Sdist(x, m) + Sdist(m, y)``.

    Args:
        x: 2-D tensor, one row per query. It is broadcast across the candidate
            axis of ``y``.
        y: 3-D tensor whose axis 1 enumerates candidates. Presumably
            (n, num_candidates, dim); TODO confirm.
        p: number of leading coordinates with ``+`` sign in the product.
        beta: scalar parameter forwarded to ``Sdist``.
    """
    # (n, d) -> (n, y.shape[1], d) as a broadcast view instead of a repeated copy.
    x = x.unsqueeze(1).expand(-1, y.shape[1], -1)
    # Shared intermediate point; previously rhob_a(x, y, p) was evaluated twice.
    mid = rhob_a(x, y, p)
    return Sdist(x, mid, beta, p) + Sdist(mid, y, beta, p)

def rhob_a(a, b, p):
    """Combine ``a``'s first ``p`` coordinates with ``b``'s rescaled remaining ones.

    The trailing block of ``b`` (from index ``p`` on, axis 2) is multiplied by
    ``||a[:, :, :p]|| / ||b[:, :, :p]||``, computed per position. The result is
    undefined (division by zero) where ``b``'s first ``p`` coordinates are all zero.

    Args:
        a, b: 3-D tensors; coordinates live on axis 2.
        p: split index between the leading and trailing coordinate blocks.

    Returns:
        3-D tensor ``cat([a[:, :, :p], scaled b[:, :, p:]], dim=2)``.
    """
    head_block = a[:, :, :p]
    tail_block = b[:, :, p:]
    numer = torch.linalg.norm(head_block, ord=2, dim=2, keepdim=True)
    denom = torch.linalg.norm(b[:, :, :p], ord=2, dim=2, keepdim=True)
    return torch.cat([head_block, tail_block * numer / denom], dim=2)

def Sdist(A, B, beta, p):
    """Element-wise pseudo-distance between matching points of ``A`` and ``B``.

    Let ``<a, b> = sum(a[:p] * b[:p]) - sum(a[p:] * b[p:])`` (over axis 2) and
    ``t = <A, B> / beta``. The result is:

    - ``sqrt(|beta|) * acos(|t|)`` where ``|t| < 1``;
    - ``sqrt(|beta|) * acosh(|t|)`` otherwise.

    Args:
        A, B: 3-D tensors (broadcast together); coordinates live on axis 2.
        beta: non-zero scalar, either a Python number or a 0-dim tensor.
        p: number of leading coordinates entering the product with ``+`` sign.

    Returns:
        Tensor shaped like the first two axes of ``A * B``. It has the working
        floating dtype, so float64 inputs give a float64 result.
    """
    AB = torch.mul(A, B)
    # Carry beta in the working precision. torch.tensor(<python float>) is
    # float32, which silently rounded beta for float64 inputs. as_tensor also
    # accepts a tensor beta without the copy-construct warning.
    dtype = AB.dtype if AB.is_floating_point() else torch.get_default_dtype()
    beta = torch.as_tensor(beta, dtype=dtype, device=A.device)
    temp = (torch.sum(AB[:, :, :p], dim=2) - torch.sum(AB[:, :, p:], dim=2)) / beta

    scale = torch.sqrt(torch.abs(beta))
    mask = torch.abs(temp) < 1
    # zeros_like keeps temp's shape, dtype and device. The former
    # torch.zeros(...) used the default dtype and truncated float64 results.
    y = torch.zeros_like(temp)
    y[mask] = scale * torch.acos(torch.abs(temp[mask]))
    y[~mask] = scale * torch.acosh(torch.abs(temp[~mask]))
    return y

def qdot(A, B, p):
    """Row-wise indefinite product of ``A`` with ``B.t()``.

    The product is ``sum(first p coords) - sum(remaining coords)`` of ``A * B.t()``.
    ``B`` is transposed before the element-wise product, so a caller wanting
    ``<x, x>`` per row passes ``B = X.t()``.

    Args:
        A: 2-D tensor, one vector per row.
        B: 2-D tensor whose transpose is multiplied element-wise with ``A``.
        p: number of leading coordinates with ``+`` sign.

    Returns:
        1-D tensor with one value per row of ``A``.
    """
    prod = A * B.t()
    positive = prod[:, :p].sum(dim=1)
    negative = prod[:, p:].sum(dim=1)
    return positive - negative

def UltraE_score(distU, bh, bt, margin):
    """Return ``-distU**2 + bh + bt + margin``, evaluated left to right.

    Works with Python numbers or broadcast-compatible tensors.
    """
    squared = distU ** 2
    return -squared + bh + bt + margin

def Utest(X, p, c):
    """Sum over rows of ``qdot(X, X.t(), p) - c``.

    This is the per-row indefinite self-product minus ``c``, summed. Positive
    and negative deviations can cancel, because no absolute value is taken.
    """
    residuals = qdot(X, X.t(), p) - c
    return residuals.sum()