import torch
from algorithm.JOBCD.nonconvex_orth2d_quad_same_P import nonconvex_orth2d_quad_same_P
from algorithm.JOBCD.nonconvex_orth2d_quad_notsame_P import nonconvex_orth2d_quad_notsame_P

def Parallel_updateV(X, gradX, B, Lconst, theta, p):
    """Update row pairs of ``X`` in place with batched 2x2 transforms.

    Every row of ``B`` names two rows of ``X``.  For each pair a small
    subproblem ``(Q, P)`` is built from ``X``, ``gradX``, ``Lconst`` and
    ``theta``; a 2x2 matrix ``V`` is obtained from one of the two imported
    solvers, and the two selected rows of ``X`` are replaced by ``V`` times
    those rows.

    Pairs are split by the signature ``J = diag(I_p, -I_{d-p})``:

    * both indices carry the same sign ("orthogonal-like") ->
      ``nonconvex_orth2d_quad_same_P``;
    * the indices carry opposite signs ("orthogonal-unlike") ->
      ``nonconvex_orth2d_quad_notsame_P``.  Any returned ``V`` whose mean
      absolute entry exceeds 0.7 is replaced by the identity (that pair is
      left unchanged).

    Args:
        X: 2-D tensor; rows are gathered and viewed as ``(-1, 2, d)`` with
            ``d = X.shape[0]``, so ``X`` is treated as ``d x d``.  Modified
            in place.
        gradX: tensor indexed by the same row indices as ``X``; presumably
            the gradient at ``X`` -- confirm against the caller.
        B: integer index tensor, assumed to have shape ``(k, 2)``.  The
            pairs are presumably disjoint; the write-back does not
            accumulate over repeated rows.
        Lconst: scale of the identity used as the Hessian surrogate
            (``Lconst * I``).
        theta: non-negative scalar (tensor or plain number); ``theta * I`` is
            subtracted from ``P`` and ``sqrt(theta) * I`` is added to ``Q``.
        p: number of leading ``+1`` entries on the diagonal of ``J``.

    Returns:
        The same tensor object ``X`` after the in-place update.
    """
    d = X.shape[0]
    device = X.device

    # Only the diagonal of J is ever read, so keep it as a length-d sign
    # vector instead of materialising a dense d x d matrix.
    signature = torch.ones(d, dtype=torch.float32, device=device)
    signature[p:] = -1

    hessian = Lconst * torch.eye(d).to(torch.float32).to(device)

    # Sum of the two signs in a pair: +-2 when they agree, 0 when they differ.
    pair_sign_sum = torch.sum(signature[B], dim=1)
    pairs_equ = B[torch.where(pair_sign_sum != 0)[0]]
    pairs_neq = B[torch.where(pair_sign_sum == 0)[0]]

    # Handle orthogonal-like cases (skipped when there is no such pair).
    V_equ = None
    if pairs_equ.shape[0] > 0:
        Q_equ, P_equ = _pair_subproblem(X, gradX, hessian, pairs_equ, theta)
        V_equ = nonconvex_orth2d_quad_same_P(Q_equ, P_equ)  # Parallel

    # Handle orthogonal-unlike cases (skipped when there is no such pair).
    V_neq = None
    if pairs_neq.shape[0] > 0:
        Q_neq, P_neq = _pair_subproblem(X, gradX, hessian, pairs_neq, theta)
        V_neq = nonconvex_orth2d_quad_notsame_P(Q_neq, P_neq)  # Parallel
        # Fall back to the identity for solutions with large entries.
        too_large = torch.mean(torch.abs(V_neq), dim=(1, 2)) > 0.7
        V_neq[too_large] = torch.eye(2).to(device)

    # Use V to update X.  All products are formed from the pre-update X
    # before anything is written back.
    pending = []
    for pairs, V in ((pairs_equ, V_equ), (pairs_neq, V_neq)):
        if V is None:
            continue
        rows = pairs.flatten()
        pending.append((rows, torch.matmul(V, X[rows].view(-1, 2, d))))
    for rows, rotated in pending:
        X[rows] = rotated.reshape(-1, d)
    return X

def _pair_subproblem(X, gradX, hessian, pairs, theta):
    """Build the batched ``(Q, P)`` coefficients for the row pairs ``pairs``.

    ``pairs`` is an ``(n, 2)`` index tensor.  Returns ``Q`` of shape
    ``(n, 4, 4)`` and ``P`` of shape ``(n, 2, 2)``.
    """
    n = pairs.shape[0]
    device = X.device

    Z = X[pairs, :].to(torch.float32)                       # (n, 2, d)
    Z_t = Z.transpose(1, 2)
    ZZ = torch.bmm(Z, Z_t).to(torch.float32)                # (n, 2, 2)

    # Broadcasting (n, 1, 2) against (n, 2, 1) already yields (n, 2, 2).
    # No squeeze() here: it would drop the batch axis when n == 1 and make
    # the bmm calls below fail.
    idx = pairs[:, None, :]
    hessian_sub = hessian[idx, idx.transpose(1, 2)]

    eye2 = torch.eye(2).to(device).expand(n, 2, 2).to(torch.float32)
    P = (torch.bmm(gradX[pairs, :], Z_t) - torch.bmm(hessian_sub, ZZ)
         - theta * eye2)
    # as_tensor leaves tensors untouched and lets plain numbers through sqrt.
    Q = (batch_kron(ZZ, hessian_sub).transpose(1, 2)
         + torch.eye(4).to(device) * torch.sqrt(torch.as_tensor(theta)))
    return Q, P

def batch_kron(A, B):
    """Return the batched Kronecker product of ``A`` and ``B``.

    For ``A`` of shape ``(b, m, n)`` and ``B`` of shape ``(b, r, s)`` the
    result has shape ``(b, m * r, n * s)``, and ``result[i]`` is the
    Kronecker product of ``A[i]`` with ``B[i]``.
    """
    rows = A.shape[1] * B.shape[1]
    cols = A.shape[2] * B.shape[2]
    # Broadcast to (b, m, r, n, s): entry [b, i, k, j, l] = A[b,i,j] * B[b,k,l].
    outer = A[:, :, None, :, None] * B[:, None, :, None, :]
    # Merging (i, k) and (j, l) gives the Kronecker block layout.
    return outer.reshape(A.shape[0], rows, cols)