import torch
from algorithm.JOBCD.utils_JOBCD import *
import numpy as np

def nonconvex_orth2d_quad_L0(H,B):
    """Pick the better of the rotation and reflection solutions of the 2x2 subproblem.

    Both candidates are computed from the same (H, B): the rotation first via
    nonconvex_orth2d_quad_L0_R, then the reflection via nonconvex_orth2d_quad_L0_F.
    The rotation wins only when its objective is strictly smaller; on a tie
    (or a non-comparable value such as NaN) the reflection is returned.

    Returns:
        The chosen 2x2 matrix, cast to torch.float32.
    """
    rotation, rotation_obj = nonconvex_orth2d_quad_L0_R(H, B)
    reflection, reflection_obj = nonconvex_orth2d_quad_L0_F(H, B)

    best = rotation if rotation_obj < reflection_obj else reflection
    return best.to(torch.float32)

def nonconvex_orth2d_quad_L0_R(Q,B):
    """Solve the 2x2 subproblem restricted to rotations V = [[c, s], [-s, c]].

    The quadratic coefficients are the expansion of an objective in (c, s).
    NOTE(review): they match 0.5 * vec(V)^T Q vec(V) + sum(B * V) with the
    column-major vec(V) = [c, -s, s, c]. This is inferred from the coefficient
    pattern, so confirm it against the caller.

    Args:
        Q: 4x4 matrix indexed as Q[i, j]. Each off-diagonal pair is averaged
            (0.5 * (Q[i, j] + Q[j, i])), the same way nonconvex_orth2d_quad_L0_F
            treats its matrix. For a symmetric Q this gives exactly the same
            coefficients as reading Q[i, j] alone.
        B: 2x2 matrix supplying the linear term.

    Returns:
        (V, fmin): the rotation built from the minimizing (c, s) as a new
        tensor, and the objective value from nonconvex_quadratic_trigonometry_L0.
    """
    # Coefficients of c and s in sum(B * V) for V = [[c, s], [-s, c]].
    cof_a = B[0, 0] + B[1, 1]
    cof_b = B[0, 1] - B[1, 0]
    # Coefficients of c^2, s*c and s^2. Off-diagonal pairs are symmetrized so a
    # non-symmetric Q gives the same objective as the reflection branch uses.
    cof_c = 0.5 * Q[0, 0] + 0.5 * (Q[0, 3] + Q[3, 0]) + 0.5 * Q[3, 3]
    cof_d = (- 0.5 * (Q[0, 1] + Q[1, 0]) + 0.5 * (Q[0, 2] + Q[2, 0])
             - 0.5 * (Q[1, 3] + Q[3, 1]) + 0.5 * (Q[2, 3] + Q[3, 2]))
    cof_e = 0.5 * Q[1, 1] - 0.5 * (Q[1, 2] + Q[2, 1]) + 0.5 * Q[2, 2]

    c, s, fmin = nonconvex_quadratic_trigonometry_L0(cof_a, cof_b, cof_c, cof_d, cof_e)

    V = torch.tensor([[c, s], [-s, c]])
    return V, fmin

def nonconvex_orth2d_quad_L0_F(H,B):
    """Solve the 2x2 subproblem restricted to V = [[-c, s], [s, c]].

    This V has determinant -(c**2 + s**2), i.e. it is the reflection branch.
    Every off-diagonal pair of H contributes both H[i, j] and H[j, i] with
    weight 0.5, so H does not have to be symmetric.

    Args:
        H: 4x4 matrix indexed as H[i, j].
        B: 2x2 matrix supplying the linear term.

    Returns:
        (V, fmin): the reflection built from the minimizing (c, s) as a new
        tensor, and the objective value from nonconvex_quadratic_trigonometry_L0.
    """
    half = 0.5

    # Coefficients of c and s in sum(B * V).
    lin_c = -B[0, 0] + B[1, 1]
    lin_s = B[0, 1] + B[1, 0]

    # Coefficients of c^2, s*c and s^2.
    quad_cc = half * H[0, 0] - half * H[0, 3] - half * H[3, 0] + half * H[3, 3]
    quad_sc = (- half * H[0, 1] - half * H[0, 2] - half * H[1, 0] + half * H[1, 3]
               - half * H[2, 0] + half * H[2, 3] + half * H[3, 1] + half * H[3, 2])
    quad_ss = half * H[1, 1] + half * H[1, 2] + half * H[2, 1] + half * H[2, 2]

    c, s, fmin = nonconvex_quadratic_trigonometry_L0(lin_c, lin_s, quad_cc, quad_sc, quad_ss)

    reflection = torch.tensor([[-c, s], [s, c]])
    return reflection, fmin

def nonconvex_quadratic_trigonometry_L0(A,B,C,D,E):
    """Minimize f(c, s) = A*c + B*s + C*c**2 + D*s*c + E*s**2 over a finite candidate set.

    Candidates (all lie on c**2 + s**2 = 1):
      * the four axis points (0, 1), (0, -1), (1, 0), (-1, 0);
      * for every value t from get_all_critical_points, both
        (chi, t*chi) and (-chi, -t*chi) with chi = 1 / sqrt(1 + t**2),
        i.e. the two angles whose tangent is t.

    f is evaluated at every candidate and the smallest value wins. On ties,
    the first candidate in the order above is chosen, because torch.argmin
    returns the first minimal index.

    Returns:
        (c, s, fmin): the minimizing cosine/sine (0-d tensors) and the minimum
        objective value (0-d tensor in torch's default dtype).
    """
    critical_points = get_all_critical_points(A, B, C, D, E).clone().detach()
    # The helper may return a 0-d tensor (its fallback) or a 1-D tensor.
    # Flattening lets every critical point be handled uniformly.
    t = critical_points.reshape(-1)
    chi = 1 / torch.sqrt(1 + t * t)

    # Axis points are built in chi's dtype/device, so the concatenation does not
    # depend on torch.cat promoting an integer tensor.
    axis_cos = torch.tensor([0, 0, 1, -1], dtype=chi.dtype, device=chi.device)
    axis_sin = torch.tensor([1, -1, 0, 0], dtype=chi.dtype, device=chi.device)

    # Both sign branches for EVERY critical point: (chi, t*chi) and (-chi, -t*chi).
    cos_list = torch.cat([axis_cos, chi, -chi])
    sin_list = torch.cat([axis_sin, t * chi, -t * chi])

    def objective(c, s):
        return A * c + B * s + C * c ** 2 + D * s * c + E * s * s

    # Each candidate is evaluated as 0-d tensors and stored into a default-dtype
    # buffer, exactly as before.
    fs = torch.zeros(len(cos_list))
    for i, (c, s) in enumerate(zip(cos_list, sin_list)):
        fs[i] = objective(c, s)

    fmin = torch.min(fs)
    j = torch.argmin(fs)
    return cos_list[j], sin_list[j], fmin

def get_all_critical_points(a,b,c,d,e):
    """Find candidate tangents t for stationary points of the trigonometric objective.

    Builds a quartic in t from the coefficients and returns its roots after
    filtering them through ``getreal``. The quartic comes from squaring
    df/dtheta = 0 for A*cos + B*sin + C*cos^2 + D*sin*cos + E*sin^2 with
    t = tan(theta), which was checked algebraically during review. Squaring can
    add spurious roots; the caller evaluates the objective at every candidate,
    so extra roots only add candidates.

    Args:
        a, b, c, d, e: tensors. They are detached and converted via ``.numpy()``,
            so they must live on the CPU.

    Returns:
        The output of ``getreal`` on the roots, or ``torch.tensor(0)`` (a 0-d
        integer tensor) when that output is empty. NOTE(review): ``getreal`` is
        defined outside this file; presumably it keeps only the real roots, so confirm.
    """
    a, b, c, d, e = (coef.detach().numpy() for coef in (a, b, c, d, e))

    w = c - e
    # Highest degree first, as np.roots expects: t^4, t^3, t^2, t^1, t^0.
    quartic = [
        d * d - a * a,
        4 * w * d + 2 * a * b,
        4 * w * w - 2 * d * d - a * a - b * b,
        2 * b * a - 4 * d * w,
        d * d - b * b,
    ]

    roots = getreal(torch.tensor(np.roots(quartic)))
    if torch.numel(roots) > 0:
        return roots
    return torch.tensor(0)