import torch
import math
from utils.Jorthogonal_test import Jtest
def UltraE_fval_obj(T, H, P, Q, config_yaml):
    """Return the Frobenius-norm misfit between ``T`` and the model distances.

    ``H`` is multiplied by ``P``, passed through ``dproj``, multiplied by
    ``Q``, and the pairwise ``Sdistnew`` values of the resulting rows are
    reshaped to ``(N, N)`` and compared against ``T``.

    Args:
        T: Target tensor; it is subtracted from an ``(N, N)`` tensor, so it
            has to broadcast against that shape.
        H: Tensor whose first dimension ``N`` fixes the number of rows.
        P: Right-hand factor applied to ``H`` before the projection.
        Q: Right-hand factor applied after the projection.
        config_yaml: Mapping that provides ``["datafeature"]["p"]`` and
            ``["datafeature"]["beta"]``.

    Returns:
        Scalar tensor holding ``torch.norm(T - Y, p='fro')``.
    """
    n_rows = H.shape[0]
    feature_cfg = config_yaml["datafeature"]
    p = feature_cfg["p"]
    beta = feature_cfg["beta"]

    projected = dproj(p, beta, H @ P)
    embedded = projected @ Q

    # Sdistnew is declared as (A, beta, p). Passing by keyword guarantees that
    # p is used as the column split (as in dproj) and beta as the scale, and
    # not the other way round.
    Y = Sdistnew(embedded, beta=beta, p=p).reshape(n_rows, n_rows)

    return torch.norm(T - Y, p='fro')

def dproj(p, beta, X):
    """Rescale the columns of ``X`` from index ``p`` onwards, row by row.

    The first ``p`` columns are returned unchanged. The remaining columns of
    each row are first normalised to Euclidean length ``beta`` and then
    multiplied by ``sqrt(abs(beta) + ||row[:p]||^2) / beta``.

    Args:
        p: Column index at which every row is split.
        beta: Scale used in both steps (it is also a divisor, so it must be
            non-zero).
        X: Tensor indexed as ``X[:, :p]`` / ``X[:, p:]``.

    Returns:
        Tensor with the same column layout as ``X``.
    """
    head = X[:, :p]
    tail = X[:, p:]

    # Step 1: bring the tail of every row to length beta; the 1e-10 term keeps
    # the division finite when a tail is all zeros.
    tail_norm = torch.linalg.norm(tail, ord=2, dim=-1, keepdim=True)
    tail = beta * tail / (tail_norm + 1e-10)

    # Step 2: stretch the tail by a factor that depends on the head's norm.
    head_norm = torch.linalg.norm(head, ord=2, dim=-1, keepdim=True)
    stretch = torch.sqrt(abs(beta) + head_norm ** 2) / beta

    return torch.cat([head, stretch * tail], dim=1)

def simple_dist_UltraE(A, B, p, beta):
    """Return the elementwise smaller of the two values from ``dist_UltraE``.

    Args:
        A, B, p, beta: Forwarded unchanged to ``dist_UltraE``.

    Returns:
        Tensor holding, per position, the minimum of the two distance tensors.
    """
    # Stack both candidates as rows and reduce over that new leading axis.
    candidates = torch.vstack(dist_UltraE(A, B, p, beta))
    return candidates.min(dim=0).values

def dist_UltraE(A, B, p, beta):
    """Evaluate both path distances between ``A`` and ``B``.

    Returns:
        Tuple ``(dist13(A, B, p, beta), dist14(A, B, p, beta))``.
    """
    return dist13(A, B, p, beta), dist14(A, B, p, beta)

def dist13(x, y, p, beta):
    """Sum of two ``Sdist`` legs that pass through ``rhob_a(y, x, p)``.

    The first leg goes from ``y`` to the intermediate point, the second from
    the intermediate point to ``x``.
    """
    waypoint = rhob_a(y, x, p)
    leg_from_y = Sdist(y, waypoint, beta, p)
    leg_to_x = Sdist(waypoint, x, beta, p)
    return leg_from_y + leg_to_x


def dist14(x, y, p, beta):
    """Sum of two ``Sdist`` legs that pass through ``rhob_a(x, y, p)``.

    The first leg goes from ``x`` to the intermediate point, the second from
    the intermediate point to ``y``.
    """
    waypoint = rhob_a(x, y, p)
    leg_from_x = Sdist(x, waypoint, beta, p)
    leg_to_y = Sdist(waypoint, y, beta, p)
    return leg_from_x + leg_to_y

def rhob_a(a, b, p):
    """Combine the first ``p`` columns of ``a`` with a rescaled copy of ``b``'s.

    Each row of ``b[:, :p]`` is rescaled so that its Euclidean norm equals the
    norm of the matching row of ``a[:, :p]``; the result is appended to
    ``a[:, :p]`` along the column axis.

    Args:
        a: 2-D tensor; only its first ``p`` columns are read.
        b: 2-D tensor with as many rows as ``a``; only its first ``p`` columns
            are read.
        p: Number of leading columns taken from each input.

    Returns:
        Tensor with ``a.shape[0]`` rows and ``2 * p`` columns. Rows where
        ``b[:, :p]`` has zero norm are not guarded and divide by zero.
    """
    a_head = a[:, :p]
    b_head = b[:, :p]
    norm_a = torch.linalg.norm(a_head, ord=2, dim=1, keepdim=True)
    norm_b = torch.linalg.norm(b_head, ord=2, dim=1, keepdim=True)

    # b_head deliberately stays 2-D. Squeezing it would drop the column axis
    # when p == 1, and the (N,) * (N, 1) product would then broadcast to
    # (N, N) instead of scaling row by row.
    rescaled_b = b_head * norm_a / norm_b
    return torch.cat([a_head, rescaled_b], dim=1)

def Sdist(A, B, beta, p):
    """Row-wise distance computed from ``qdot(A, B, p) / beta``.

    With ``t = |qdot(A, B, p) / beta|`` the result per row is
    ``sqrt(|beta|) * acos(t)`` where ``t < 1`` and
    ``sqrt(|beta|) * acosh(t)`` otherwise.

    Args:
        A: 2-D tensor; its device is used for ``beta`` and for the result.
        B: Tensor combined elementwise with ``A`` inside ``qdot``.
        beta: Scale, given as a Python number or a tensor.
        p: Column split forwarded to ``qdot``.

    Returns:
        1-D tensor with one entry per row, in the dtype of the scaled product
        (float32 for float32 inputs).
    """
    # as_tensor accepts Python numbers and existing tensors alike, without
    # the copy (and warning) that torch.tensor() triggers on a tensor.
    beta = torch.as_tensor(beta, device=A.device)
    ratio = qdot(A, B, p) / beta
    magnitude = torch.abs(ratio)
    scale = torch.sqrt(torch.abs(beta))

    inside = magnitude < 1
    # The buffer follows the dtype of the data, so non-float32 inputs do not
    # hit a dtype mismatch (or lose precision) on the masked writes below.
    dist = ratio.new_zeros(ratio.shape[0])
    # Each inverse function is evaluated only on the entries of its own
    # domain, so no NaNs are produced for the other branch.
    dist[inside] = scale * torch.acos(magnitude[inside])
    dist[~inside] = scale * torch.acosh(magnitude[~inside])
    return dist

def qdot(A, B, p):
    """Row-wise signed inner product of ``A`` and ``B`` split at column ``p``.

    The elementwise products of the first ``p`` columns are added and those
    of the remaining columns are subtracted.

    Args:
        A, B: Tensors multiplied elementwise (shapes must broadcast).
        p: Number of leading columns that enter with a positive sign.

    Returns:
        Tensor with one value per row of the elementwise product.
    """
    products = A * B
    plus_part = products[:, :p].sum(dim=1)
    minus_part = products[:, p:].sum(dim=1)
    return -minus_part + plus_part


def compute_distance_matrix(A, B, beta, p):
    """Map the scaled ``qdot`` similarity to a distance, region by region.

    The similarity ``K = qdot(A, B, p) / abs(beta) ** 2`` is split into four
    disjoint regions with a tolerance of 1e-5 around -1, and each region is
    overwritten with its own formula:

    * ``K < -1 - tol``: ``beta * acosh(-K)``
    * ``-1 - tol <= K < -1 + tol``: ``beta * |2 * (1 + K)|``
    * ``K >= 0``: ``beta * (pi / 2 + K)``
    * otherwise: ``beta * acos(-K)``

    Args:
        A, B: Tensors forwarded to ``qdot``.
        beta: Scale; its absolute value normalises ``K`` and its signed value
            multiplies every result.
        p: Column split forwarded to ``qdot``.

    Returns:
        Tensor of the same shape as ``qdot(A, B, p)``.
    """
    tol = 0.00001
    sim = qdot(A, B, p) / abs(beta) ** 2

    # Every mask is taken from the untouched similarities before any entry is
    # overwritten, so the in-place updates below cannot influence each other.
    below_upper_edge = sim < -1.0 + tol
    in_acosh_region = sim < -1.0 - tol
    in_linear_band = below_upper_edge & (~in_acosh_region)
    is_non_negative = sim >= 0.0
    in_acos_region = (~is_non_negative) & (~below_upper_edge)

    sim[in_acosh_region] = beta * torch.acosh(-sim[in_acosh_region])
    sim[in_linear_band] = beta * torch.abs(2.0 * (1.0 + sim[in_linear_band]))
    sim[is_non_negative] = beta * (math.pi / 2 + sim[is_non_negative])
    sim[in_acos_region] = beta * torch.acos(-sim[in_acos_region])
    return sim

def Utest(X, p, c):
    """Return the summed deviation of ``qdot(X, X.t(), p)`` from ``c``.

    Args:
        X: 2-D tensor.
        p: Column split forwarded to ``qdot``.
        c: Reference value subtracted from every ``qdot`` entry.

    Returns:
        Scalar tensor with the signed sum of the deviations.
    """
    # NOTE(review): qdot multiplies its arguments elementwise, so X and X.t()
    # must have broadcast-compatible shapes (e.g. square X) — confirm intent.
    deviation = qdot(X, X.t(), p) - c
    return deviation.sum()


def create_row_combinations(tensorA):
    """Enumerate all ordered pairs of rows of a 2-D tensor.

    For an ``(n, m)`` input two ``(n * n, m)`` tensors are returned such that
    row ``i * n + j`` of the first is ``tensorA[i]`` and row ``i * n + j`` of
    the second is ``tensorA[j]``.

    Args:
        tensorA: 2-D tensor.

    Returns:
        Tuple ``(first_rows, second_rows)`` as described above.
    """
    row_count, _ = tensorA.shape

    # Each row repeated row_count times in a run: A0, A0, ..., A1, A1, ...
    first_rows = tensorA.repeat_interleave(row_count, dim=0)
    # The whole tensor tiled row_count times: A0, A1, ..., A0, A1, ...
    second_rows = tensorA.repeat(row_count, 1)

    return first_rows, second_rows


def qdot_optimized(A, p):
    """Signed inner product between every pair of rows of ``A``.

    Entry ``[i, k]`` is the dot product of rows ``i`` and ``k`` over the
    first ``p`` columns minus their dot product over the remaining columns.

    Args:
        A: 2-D tensor.
        p: Number of leading columns that enter with a positive sign.

    Returns:
        ``(N, N)`` tensor for an input with ``N`` rows.
    """
    pairwise = 'ij,kj->ik'
    head = A[:, :p]
    tail = A[:, p:]
    plus_part = torch.einsum(pairwise, head, head)
    minus_part = torch.einsum(pairwise, tail, tail)
    return -minus_part + plus_part


def Sdistnew(A, beta, p):
    """Pairwise distances between the rows of ``A``, flattened to 1-D.

    With ``t = |qdot_optimized(A, p) / beta|`` each pair gets
    ``sqrt(|beta|) * acos(t)`` where ``t < 1`` and
    ``sqrt(|beta|) * acosh(t)`` otherwise.

    Args:
        A: 2-D tensor; its device is used for ``beta`` and for the result.
        beta: Scale, given as a Python number or a tensor.
        p: Column split forwarded to ``qdot_optimized``.

    Returns:
        1-D tensor of length ``N * N`` (row-major over the pair matrix) for
        an input with ``N`` rows, in the dtype of the scaled product
        (float32 for float32 inputs).
    """
    # as_tensor accepts Python numbers and existing tensors alike, without
    # the copy (and warning) that torch.tensor() triggers on a tensor.
    beta = torch.as_tensor(beta, device=A.device)
    ratio = qdot_optimized(A, p) / beta
    magnitude = torch.abs(ratio)
    scale = torch.sqrt(torch.abs(beta))

    inside = magnitude < 1
    # The buffer follows the dtype of the data, so non-float32 inputs do not
    # hit a dtype mismatch (or lose precision) on the masked writes below.
    dist = torch.zeros_like(ratio)
    # Each inverse function is evaluated only on the entries of its own
    # domain, so no NaNs are produced for the other branch.
    dist[inside] = scale * torch.acos(magnitude[inside])
    dist[~inside] = scale * torch.acosh(magnitude[~inside])
    return dist.view(-1)