import torch
import numpy as np
from algorithm.JOBCD.utils_JOBCD import *
def nonconvex_orth2d_quad_notsame(H, B):
    """Return the better of the two 2x2 candidate solutions, cast to float32.

    Two candidate families are solved independently:

    * ``nonconvex_orth2d_quad_L0_aaequ``  -> V of the form [[c, s], [s, c]]
    * ``nonconvex_orth2d_quad_L0_aanequ`` -> V of the form [[-c, -s], [s, c]]

    The ``aaequ`` candidate is kept only when its objective value is strictly
    smaller; on a tie (or any False comparison, e.g. involving NaN) the
    ``aanequ`` candidate is returned.

    Args:
        H: matrix forwarded as ``Q`` to both solvers, which index it up to
            [3, 3] (presumably 4x4 -- confirm with callers).
        B: matrix indexed up to [1, 1] by both solvers.

    Returns:
        The selected 2x2 matrix as a ``torch.float32`` tensor.
    """
    V_same_sign, f_same_sign = nonconvex_orth2d_quad_L0_aaequ(H, B)
    V_flipped, f_flipped = nonconvex_orth2d_quad_L0_aanequ(H, B)
    winner = V_same_sign if f_same_sign < f_flipped else V_flipped
    return winner.to(torch.float32)

def nonconvex_orth2d_quad_L0_aaequ(Q, B):
    """Solve the 2x2 subproblem restricted to V = [[c, s], [s, c]].

    Expanding ``<B, V> + 0.5 * vec(V)^T Q vec(V)`` with vec(V) = [c, s, s, c]
    gives the scalar objective A*c + B*s + C*c**2 + D*s*c + E*s**2, whose five
    coefficients are computed below. Both trigonometry solvers minimise it.
    The ``_pos`` result is kept only when its value is strictly smaller;
    otherwise the ``_neg`` result is used.

    Args:
        Q: matrix indexed up to [3, 3] (presumably 4x4 -- confirm).
        B: matrix indexed up to [1, 1].

    Returns:
        (V, fmin): the tensor [[c, s], [s, c]] and the objective value there.
    """
    coefficients = (
        B[0, 0] + B[1, 1],
        B[0, 1] + B[1, 0],
        0.5 * (Q[0, 0] + Q[0, 3] + Q[3, 0] + Q[3, 3]),
        0.5 * (Q[1, 0] + Q[2, 0] + Q[0, 1] + Q[3, 1] + Q[0, 2] + Q[3, 2] + Q[1, 3] + Q[2, 3]),
        0.5 * (Q[1, 1] + Q[2, 1] + Q[1, 2] + Q[2, 2]),
    )

    pos_result = nonconvex_quadratic_trigonometry_L0_xy_pos(*coefficients)
    neg_result = nonconvex_quadratic_trigonometry_L0_xy_neg(*coefficients)
    c, s, fmin = pos_result if pos_result[2] < neg_result[2] else neg_result
    return torch.tensor([[c, s], [s, c]]), fmin

def nonconvex_orth2d_quad_L0_aanequ(Q, B):
    """Solve the 2x2 subproblem restricted to V = [[-c, -s], [s, c]].

    Expanding ``<B, V> + 0.5 * vec(V)^T Q vec(V)`` with the column-major
    vec(V) = [-c, s, -s, c] gives A*c + B*s + C*c**2 + D*s*c + E*s**2, whose
    five coefficients are computed below. Both trigonometry solvers minimise
    it. The ``_pos`` result is kept only when its value is strictly smaller;
    otherwise the ``_neg`` result is used.

    Args:
        Q: matrix indexed up to [3, 3] (presumably 4x4 -- confirm).
        B: matrix indexed up to [1, 1].

    Returns:
        (V, fmin): the tensor [[-c, -s], [s, c]] and the objective value there.
    """
    coefficients = (
        -B[0, 0] + B[1, 1],
        -B[0, 1] + B[1, 0],
        0.5 * (Q[0, 0] - Q[0, 3] - Q[3, 0] + Q[3, 3]),
        0.5 * (-Q[1, 0] + Q[2, 0] - Q[0, 1] + Q[3, 1] + Q[0, 2] - Q[3, 2] + Q[1, 3] - Q[2, 3]),
        0.5 * (Q[1, 1] - Q[2, 1] - Q[1, 2] + Q[2, 2]),
    )

    pos_result = nonconvex_quadratic_trigonometry_L0_xy_pos(*coefficients)
    neg_result = nonconvex_quadratic_trigonometry_L0_xy_neg(*coefficients)
    c, s, fmin = pos_result if pos_result[2] < neg_result[2] else neg_result
    return torch.tensor([[-c, -s], [s, c]]), fmin

def _best_hyperbolic_candidate(A, B, C, D, E, critical_points, cos_sign):
    """Minimise A*c + B*s + C*c**2 + D*s*c + E*s*s over a finite candidate set.

    Candidates, in this order (``torch.argmin`` keeps the first of tied
    values):

    * the fixed points (c, s) = (1, 0) and (-1, 0);
    * with chi = 1 / sqrt(1 - t**2), for every critical point t with |t| < 1:
      first (cos_sign * chi, t * chi) for all t, then
      (-cos_sign * chi, -t * chi) for all t.

    Every candidate satisfies c**2 - s**2 = 1.

    Args:
        A, B, C, D, E: objective coefficients (tensors).
        critical_points: tensor of candidate t values; any shape, flattened.
        cos_sign: +1 or -1, the sign of c in the first branch.

    Returns:
        (c, s, fmin): the best candidate as 0-d tensors and its objective
        value as a float32 0-d tensor.
    """
    t = critical_points.detach().reshape(-1)
    # Keep only the open interval |t| < 1. At |t| == 1, chi is infinite and
    # the objective becomes inf/NaN. NaN roots also fail this test, so they
    # cannot poison torch.min below.
    t = t[t.abs() < 1]
    chi = 1 / torch.sqrt(1 - t * t)

    # Every critical point is paired with BOTH sign choices of (c, s).
    cos_list = torch.cat([torch.tensor([1, -1]), cos_sign * chi, -cos_sign * chi])
    sin_list = torch.cat([torch.tensor([0, 0]), t * chi, -t * chi])

    # Values are stored as float32, as the per-candidate zeros buffer did before.
    fs = (
        A * cos_list
        + B * sin_list
        + C * cos_list ** 2
        + D * sin_list * cos_list
        + E * sin_list * sin_list
    ).to(torch.float32)

    j = torch.argmin(fs)
    return cos_list[j], sin_list[j], torch.min(fs)


def nonconvex_quadratic_trigonometry_L0_xy_pos(A, B, C, D, E):
    """Minimise A*c + B*s + C*c**2 + D*s*c + E*s**2 over points with c**2 - s**2 = 1.

    Critical points t come from ``get_all_critical_points_xy_pos``. With
    chi = 1/sqrt(1 - t**2), each t in the open interval |t| < 1 yields the
    candidates (chi, t*chi) and (-chi, -t*chi). These are evaluated together
    with the fixed points (+-1, 0).

    Returns:
        (c, s, fmin) for the best candidate; fmin is a float32 0-d tensor.
    """
    critical_points = get_all_critical_points_xy_pos(A, B, C, D, E)
    return _best_hyperbolic_candidate(A, B, C, D, E, critical_points, cos_sign=1)


def nonconvex_quadratic_trigonometry_L0_xy_neg(A, B, C, D, E):
    """Minimise A*c + B*s + C*c**2 + D*s*c + E*s**2 over points with c**2 - s**2 = 1.

    Critical points t come from ``get_all_critical_points_xy_neg``. With
    chi = 1/sqrt(1 - t**2), each t in the open interval |t| < 1 yields the
    candidates (-chi, t*chi) and (chi, -t*chi). These are evaluated together
    with the fixed points (+-1, 0).

    Returns:
        (c, s, fmin) for the best candidate; fmin is a float32 0-d tensor.
    """
    critical_points = get_all_critical_points_xy_neg(A, B, C, D, E)
    return _best_hyperbolic_candidate(A, B, C, D, E, critical_points, cos_sign=-1)

def _real_quartic_roots(coeffs):
    """Return ``getreal`` applied to the roots of a quartic.

    Args:
        coeffs: [c4, c3, c2, c1, c0], highest degree first (np.roots order).

    Returns:
        The tensor produced by ``getreal``, or ``torch.tensor(0)`` when that
        tensor is empty, so callers always receive at least one point.

    Raises:
        Whatever ``np.roots`` / ``torch.tensor`` raised. The coefficients are
        printed first for debugging.
    """
    try:
        roots = torch.tensor(np.roots(coeffs))
    except Exception:
        # Report the offending coefficients, then surface the real error.
        # Previously execution fell through and failed on an unbound name.
        print(coeffs)
        raise
    # NOTE(review): getreal comes from utils_JOBCD; presumably it keeps the
    # real-valued roots -- confirm there.
    x = getreal(roots)
    if torch.numel(x) == 0:
        x = torch.tensor(0)
    return x


def get_all_critical_points_xy_pos(a, b, c, d, e):
    """Candidate critical points t for the ``_xy_pos`` trigonometry solver.

    Write f(x, y) = a*x + b*y + c*x**2 + d*y*x + e*y**2 and substitute
    x = +-1/sqrt(1 - t**2), y = t*x (so x**2 - y**2 = 1). Setting df/dt = 0
    and squaring away the square root gives, with w = c + e:

        (d**2 + a**2) t**4 + (4wd + 2ab) t**3 + (4w**2 + 2d**2 - a**2 + b**2) t**2
            + (4wd - 2ab) t + (d**2 - b**2) = 0

    Squaring can introduce spurious roots. The caller evaluates f at every
    candidate and keeps the minimum.

    Args:
        a, b, c, d, e: scalar coefficient tensors. They are converted with
            ``.detach().numpy()``, so they must be CPU tensors.

    Returns:
        Tensor of critical points, or ``torch.tensor(0)`` if none were found.
    """
    a, b, c, d, e = (v.detach().numpy() for v in (a, b, c, d, e))

    w = c + e
    c0 = d * d - b * b
    c1 = 4 * d * w - 2 * a * b
    c2 = 4 * w * w + 2 * d * d - a * a + b * b
    c3 = 4 * w * d + 2 * a * b
    c4 = d * d + a * a
    return _real_quartic_roots([c4, c3, c2, c1, c0])


def get_all_critical_points_xy_neg(a, b, c, d, e):
    """Candidate critical points t for the ``_xy_neg`` trigonometry solver.

    This is the ``get_all_critical_points_xy_pos`` quartic with t replaced by
    -t: the odd-degree coefficients c3 and c1 are negated. Mathematically,
    its roots are therefore the negatives of the ``_pos`` roots.

    Args:
        a, b, c, d, e: scalar coefficient tensors. They are converted with
            ``.detach().numpy()``, so they must be CPU tensors.

    Returns:
        Tensor of critical points, or ``torch.tensor(0)`` if none were found.
    """
    a, b, c, d, e = (v.detach().numpy() for v in (a, b, c, d, e))

    w = c + e
    c0 = d * d - b * b
    c1 = -4 * d * w + 2 * a * b
    c2 = 4 * w * w + 2 * d * d - a * a + b * b
    c3 = -4 * w * d - 2 * a * b
    c4 = d * d + a * a
    return _real_quartic_roots([c4, c3, c2, c1, c0])