import torch
import numpy as np
from utils.draw_real import draw_real


def U_proj(x, p):
    """Project ``x`` onto U = {u : ||u[:p]||^2 - ||u[p:]||^2 = -1}.

    Stationary points of ||u - x||^2 under that constraint have the form
    ``u[:p] = 2 x[:p] / (2 + lam)`` and ``u[p:] = 2 x[p:] / (2 - lam)`` for a
    scalar multiplier ``lam``. Substituting this into the constraint and
    clearing denominators gives the quartic solved below. Each real root
    yields a candidate, and the candidate closest to ``x`` is returned.

    Args:
        x: Tensor to project, split along its first dimension at index ``p``.
            Both 1-D ``(N,)`` and column ``(N, 1)`` tensors are supported.
        p: Number of leading entries that form the first block.

    Returns:
        The closest candidate, with the same shape, dtype and device as ``x``.

    Raises:
        ValueError: If ``draw_real`` returns no real roots.

    NOTE(review): if ``x[p:]`` is all zero, ``lam = 2`` is a root of the
    quartic, and that candidate's second block is 0/0 (NaN).
    """
    # Squared norms of the two blocks. .item() hands np.roots plain floats;
    # converting tensors directly fails for CUDA tensors or tensors that
    # require grad.
    m = (torch.linalg.norm(x[:p], 2) ** 2).item()
    n = (torch.linalg.norm(x[p:], 2) ** 2).item()

    # Quartic in lam (highest degree first), obtained from the constraint
    # after multiplying through by (4 - lam^2)^2 / 4.
    coeffs = [1 / 4, 0, m - n - 2, -4 * m - 4 * n, 4 * m - 4 * n + 4]
    # NOTE(review): draw_real presumably keeps the real roots as a 1-D
    # tensor; confirm against utils.draw_real.
    roots = draw_real(torch.tensor(np.roots(coeffs)))
    if len(roots) == 0:
        raise ValueError("U_proj: the projection quartic has no real roots")

    def candidate(lam):
        # Stationary point for multiplier lam. It is built like x (same
        # shape, dtype and device), so 1-D and column inputs both work.
        u = torch.zeros_like(x)
        u[:p] = 2 * x[:p] / (2 + lam)
        u[p:] = 2 * x[p:] / (2 - lam)
        return u

    candidates = [candidate(lam) for lam in roots]
    dists = torch.stack(
        [torch.linalg.norm(x - u, ord=2) ** 2 for u in candidates]
    )
    # argmin always returns a single index (the first minimiser on ties).
    return candidates[int(torch.argmin(dists))]
