from algorithm.JOBCD.utils_JOBCD import *
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import os

def nonconvex_orth2d_quad_notsame_P(H, B):
    """Return, per batch element, the 2x2 matrix of the better candidate family.

    Both ``nonconvex_orth2d_quad_L0_aaequ`` and ``nonconvex_orth2d_quad_L0_aanequ``
    are evaluated on ``(H, B)``. For every batch element the matrix whose
    reported objective is smaller is kept. On ties, ``torch.argmin`` picks the
    first entry, i.e. the ``aaequ`` result.

    Returns:
        Tensor of shape ``(H.shape[0], 2, 2)`` on ``H.device``.
    """
    V_eq, f_eq = nonconvex_orth2d_quad_L0_aaequ(H, B)
    V_neq, f_neq = nonconvex_orth2d_quad_L0_aanequ(H, B)

    # argmin over the stacked objectives yields 0 (aaequ) or 1 (aanequ).
    take_eq = torch.argmin(torch.stack([f_eq, f_neq]), dim=0) == 0

    # Allocate with the default dtype, then copy the selected matrices in.
    V = torch.zeros(H.shape[0], 2, 2).to(H.device)
    V.copy_(torch.where(take_eq[:, None, None], V_eq, V_neq))
    return V

def nonconvex_orth2d_quad_L0_aaequ(Q,B):
    """Solve the subproblem over matrices of the form ``[[c, s], [s, c]]``.

    The objective is reduced to ``A*c + B*s + C*c**2 + D*s*c + E*s*s``.
    The linear coefficients come from ``B`` and the quadratic coefficients from
    ``Q``. Both branch solvers (``_xy_pos`` and ``_xy_neg``) are evaluated, and
    for every batch element the branch with the smaller objective is kept.

    Args:
        Q: batched tensor indexed as ``Q[:, i, j]`` with ``i, j`` in ``0..3``
            (presumably shape ``(n, 4, 4)`` -- TODO confirm).
        B: batched tensor indexed as ``B[:, i, j]`` with ``i, j`` in ``0..1``
            (presumably shape ``(n, 2, 2)`` -- TODO confirm).

    Returns:
        ``(V, fmin)``. ``V`` has shape ``(B.shape[0], 2, 2)`` and lives on
        ``Q.device``. ``fmin`` is the per-element objective of the chosen branch.
    """
    # Linear coefficients (multiplying c and s), taken from B.
    lin_c = B[:, 0, 0] + B[:, 1, 1]
    lin_s = B[:, 0, 1] + B[:, 1, 0]
    # Quadratic coefficients (multiplying c*c, s*c and s*s), taken from Q.
    quad_cc = 0.5 * (Q[:, 0, 0] + Q[:, 0, 3] + Q[:, 3, 0] + Q[:, 3, 3])
    quad_sc = 0.5 * (Q[:, 1, 0] + Q[:, 2, 0] + Q[:, 0, 1] + Q[:, 3, 1] + Q[:, 0, 2] + Q[:, 3, 2] + Q[:, 1, 3] + Q[:, 2, 3])
    quad_ss = 0.5 * (Q[:, 1, 1] + Q[:, 2, 1] + Q[:, 1, 2] + Q[:, 2, 2])

    coeffs = (lin_c, lin_s, quad_cc, quad_sc, quad_ss)
    c_pos, s_pos, fmin_pos = nonconvex_quadratic_trigonometry_L0_xy_pos(*coeffs)
    c_neg, s_neg, fmin_neg = nonconvex_quadratic_trigonometry_L0_xy_neg(*coeffs)

    # Row-major entries of [[c, s], [s, c]], folded back into a 2x2 trailing shape.
    V_pos = torch.stack((c_pos, s_pos, s_pos, c_pos), dim=-1).reshape(*c_pos.shape, 2, 2)
    V_neg = torch.stack((c_neg, s_neg, s_neg, c_neg), dim=-1).reshape(*c_neg.shape, 2, 2)

    fmin, branch = torch.min(torch.stack([fmin_pos, fmin_neg]), dim=0)
    take_pos = (branch == 0)[:, None, None]

    # Default-dtype buffer on Q's device; copy_ casts the selection into it.
    V = torch.zeros(B.shape[0], 2, 2).to(Q.device)
    V.copy_(torch.where(take_pos, V_pos, V_neg))
    return V, fmin

def nonconvex_orth2d_quad_L0_aanequ(Q,B):
    """Solve the subproblem over matrices of the form ``[[-c, -s], [s, c]]``.

    The objective is reduced to ``A*c + B*s + C*c**2 + D*s*c + E*s*s``.
    The linear coefficients come from ``B`` and the quadratic coefficients from
    ``Q``; their sign pattern differs from the ``aaequ`` variant. Both branch
    solvers are evaluated, and for every batch element the branch with the
    smaller objective is kept.

    Args:
        Q: batched tensor indexed as ``Q[:, i, j]`` with ``i, j`` in ``0..3``
            (presumably shape ``(n, 4, 4)`` -- TODO confirm).
        B: batched tensor indexed as ``B[:, i, j]`` with ``i, j`` in ``0..1``
            (presumably shape ``(n, 2, 2)`` -- TODO confirm).

    Returns:
        ``(V, fmin)``. ``V`` has shape ``(B.shape[0], 2, 2)`` and lives on
        ``Q.device``. ``fmin`` is the per-element objective of the chosen branch.
    """
    # Linear coefficients (multiplying c and s), taken from B.
    lin_c = -B[:, 0, 0] + B[:, 1, 1]
    lin_s = -B[:, 0, 1] + B[:, 1, 0]
    # Quadratic coefficients (multiplying c*c, s*c and s*s), taken from Q.
    quad_cc = 0.5 * (Q[:, 0, 0] - Q[:, 0, 3] - Q[:, 3, 0] + Q[:, 3, 3])
    quad_sc = 0.5 * (-Q[:, 1, 0] + Q[:, 2, 0] - Q[:, 0, 1] + Q[:, 3, 1] + Q[:, 0, 2] - Q[:, 3, 2] + Q[:, 1, 3] - Q[:, 2, 3])
    quad_ss = 0.5 * (Q[:, 1, 1] - Q[:, 2, 1] - Q[:, 1, 2] + Q[:, 2, 2])

    coeffs = (lin_c, lin_s, quad_cc, quad_sc, quad_ss)
    c_pos, s_pos, fmin_pos = nonconvex_quadratic_trigonometry_L0_xy_pos(*coeffs)
    c_neg, s_neg, fmin_neg = nonconvex_quadratic_trigonometry_L0_xy_neg(*coeffs)

    # Row-major entries of [[-c, -s], [s, c]], folded back into a 2x2 trailing shape.
    V_pos = torch.stack((-c_pos, -s_pos, s_pos, c_pos), dim=-1).reshape(*c_pos.shape, 2, 2)
    V_neg = torch.stack((-c_neg, -s_neg, s_neg, c_neg), dim=-1).reshape(*c_neg.shape, 2, 2)

    fmin, branch = torch.min(torch.stack([fmin_pos, fmin_neg]), dim=0)
    take_pos = (branch == 0)[:, None, None]

    # Default-dtype buffer on Q's device; copy_ casts the selection into it.
    V = torch.zeros(B.shape[0], 2, 2).to(Q.device)
    V.copy_(torch.where(take_pos, V_pos, V_neg))
    return V, fmin

def nonconvex_quadratic_trigonometry_L0_xy_pos(A,B,C,D,E):
    """Minimise ``A*c + B*s + C*c**2 + D*s*c + E*s*s`` over a finite candidate set.

    Every candidate ``(c, s)`` satisfies ``c**2 - s**2 == 1``. The candidates for
    each batch element are:

    * ``(1, 0)`` and ``(-1, 0)``;
    * for each root ``t`` from ``get_all_critical_points_xy_pos``, the points
      ``(1, t) / sqrt(1 - t**2)`` and ``(-1, -t) / sqrt(1 - t**2)``.

    A root that yields a non-finite point (``|t| >= 1`` or a non-finite ``t``)
    is replaced by ``(1, 0)``.

    Args:
        A, B, C, D, E: per-element objective coefficients (assumed 1-D tensors
            of length n on a common device -- TODO confirm).

    Returns:
        ``(c, s, fmin)``: the minimising point as float32 tensors of length n,
        and the objective value evaluated at that point.
    """
    nequ = A.shape[0]
    t = get_all_critical_points_xy_pos(A, B, C, D, E).clone().detach()

    denom = torch.sqrt(1 - t * t)
    # Use the roots' float dtype explicitly rather than relying on int/float
    # promotion inside torch.cat.
    base_c = torch.tensor([1, -1], dtype=t.dtype, device=A.device).repeat(nequ, 1)
    base_s = torch.zeros_like(base_c)
    cos_list = torch.cat([base_c, 1 / denom, -1 / denom], dim=1)
    sin_list = torch.cat([base_s, t / denom, -t / denom], dim=1)

    # |t| > 1 gives NaN, and |t| == 1 gives +-inf (1 / sqrt(0)). Either would make
    # the objective non-finite (e.g. inf - inf -> NaN) and could be picked as
    # the minimum. Replace every non-finite candidate with the point (1, 0).
    bad = ~(torch.isfinite(cos_list) & torch.isfinite(sin_list))
    cos_list = cos_list.masked_fill(bad, 1)
    sin_list = sin_list.masked_fill(bad, 0)

    fs = (A.unsqueeze(1) * cos_list + B.unsqueeze(1) * sin_list
          + C.unsqueeze(1) * cos_list ** 2 + D.unsqueeze(1) * sin_list * cos_list
          + E.unsqueeze(1) * sin_list * sin_list)

    fmin, j = torch.min(fs, dim=1)
    # gather keeps the index on the same device as the candidate tensors.
    pick = j.unsqueeze(1)
    c = cos_list.gather(1, pick).squeeze(1).to(torch.float32)
    s = sin_list.gather(1, pick).squeeze(1).to(torch.float32)
    return c, s, fmin

def nonconvex_quadratic_trigonometry_L0_xy_neg(A,B,C,D,E):
    """Minimise ``A*c + B*s + C*c**2 + D*s*c + E*s*s`` over a finite candidate set.

    Every candidate ``(c, s)`` satisfies ``c**2 - s**2 == 1``. The candidates for
    each batch element are:

    * ``(1, 0)`` and ``(-1, 0)``;
    * for each root ``t`` from ``get_all_critical_points_xy_neg``, the points
      ``(-1, t) / sqrt(1 - t**2)`` and ``(1, -t) / sqrt(1 - t**2)``.

    A root that yields a non-finite point (``|t| >= 1`` or a non-finite ``t``)
    is replaced by ``(1, 0)``.

    Args:
        A, B, C, D, E: per-element objective coefficients (assumed 1-D tensors
            of length n on a common device -- TODO confirm).

    Returns:
        ``(c, s, fmin)``: the minimising point as float32 tensors of length n,
        and the objective value evaluated at that point.
    """
    nequ = A.shape[0]
    t = get_all_critical_points_xy_neg(A, B, C, D, E).clone().detach()

    denom = torch.sqrt(1 - t * t)
    # Use the roots' float dtype explicitly rather than relying on int/float
    # promotion inside torch.cat.
    base_c = torch.tensor([1, -1], dtype=t.dtype, device=A.device).repeat(nequ, 1)
    base_s = torch.zeros_like(base_c)
    cos_list = torch.cat([base_c, -1 / denom, 1 / denom], dim=1)
    sin_list = torch.cat([base_s, t / denom, -t / denom], dim=1)

    # |t| > 1 gives NaN, and |t| == 1 gives +-inf (1 / sqrt(0)). Either would make
    # the objective non-finite (e.g. inf - inf -> NaN) and could be picked as
    # the minimum. Replace every non-finite candidate with the point (1, 0).
    bad = ~(torch.isfinite(cos_list) & torch.isfinite(sin_list))
    cos_list = cos_list.masked_fill(bad, 1)
    sin_list = sin_list.masked_fill(bad, 0)

    fs = (A.unsqueeze(1) * cos_list + B.unsqueeze(1) * sin_list
          + C.unsqueeze(1) * cos_list ** 2 + D.unsqueeze(1) * sin_list * cos_list
          + E.unsqueeze(1) * sin_list * sin_list)

    fmin, j = torch.min(fs, dim=1)
    # gather keeps the index on the same device as the candidate tensors.
    pick = j.unsqueeze(1)
    c = cos_list.gather(1, pick).squeeze(1).to(torch.float32)
    s = sin_list.gather(1, pick).squeeze(1).to(torch.float32)
    return c, s, fmin

def get_all_critical_points_xy_pos(a, b, c, d, e):
    """Return, per batch element, the four roots of the ``pos`` quartic in ``t``.

    With ``w = c + e`` the polynomial is

        ``(d*t**2 + 2*w*t + d)**2 - (a*t + b)**2 * (1 - t**2)``,

    whose coefficients (highest degree first) are ``c4 .. c0`` below. Roots are
    computed on the CPU with ``numpy.roots``, one row at a time. They are passed
    through ``getreal`` and moved back to ``a.device``.

    ``numpy.roots`` strips leading zero coefficients, so when ``c4 == 0`` it
    returns fewer than four roots. The unused slots of that row are left at 0.
    A row whose coefficients are non-finite (``numpy`` raises ``LinAlgError``)
    is left entirely at 0.
    """
    device = a.device
    a, b, c, d, e = (v.cpu().detach().numpy() for v in (a, b, c, d, e))

    w = c + e
    c0 = d * d - b * b
    c1 = 4 * d * w - 2 * a * b
    c2 = 4 * w * w + 2 * d * d - a * a + b * b
    c3 = 4 * w * d + 2 * a * b
    c4 = d * d + a * a
    coeff_rows = np.stack([c4, c3, c2, c1, c0], axis=1)

    ts = np.zeros((len(a), 4), dtype=complex)
    for i, row in enumerate(coeff_rows):
        try:
            roots = np.roots(row)
        except np.linalg.LinAlgError:
            # Non-finite coefficients: keep the all-zero row as a best-effort fallback.
            continue
        # The degree can drop below 4 when leading coefficients are zero. Pad
        # with zeros instead of discarding the roots that were found. t = 0
        # maps to the (+-1, 0) points that the solvers already consider.
        ts[i, :roots.size] = roots

    x = getreal(torch.tensor(ts))
    return x.to(device)

def get_all_critical_points_xy_neg(a, b, c, d, e):
    """Return, per batch element, the four roots of the ``neg`` quartic in ``t``.

    With ``w = c + e`` the polynomial is

        ``(d*t**2 - 2*w*t + d)**2 - (a*t - b)**2 * (1 - t**2)``,

    i.e. the ``pos`` quartic with ``t -> -t``. Its coefficients (highest degree
    first) are ``c4 .. c0`` below. Roots are computed on the CPU with
    ``numpy.roots``, one row at a time. They are passed through ``getreal`` and
    moved back to ``a.device``.

    ``numpy.roots`` strips leading zero coefficients, so when ``c4 == 0`` it
    returns fewer than four roots. The unused slots of that row are left at 0.
    A row whose coefficients are non-finite (``numpy`` raises ``LinAlgError``)
    is left entirely at 0.
    """
    device = a.device
    a, b, c, d, e = (v.cpu().detach().numpy() for v in (a, b, c, d, e))

    w = c + e
    c0 = d * d - b * b
    c1 = -4 * d * w + 2 * a * b
    c2 = 4 * w * w + 2 * d * d - a * a + b * b
    c3 = -4 * w * d - 2 * a * b
    c4 = d * d + a * a
    coeff_rows = np.stack([c4, c3, c2, c1, c0], axis=1)

    ts = np.zeros((len(a), 4), dtype=complex)
    for i, row in enumerate(coeff_rows):
        try:
            roots = np.roots(row)
        except np.linalg.LinAlgError:
            # Non-finite coefficients: keep the all-zero row as a best-effort fallback.
            continue
        # The degree can drop below 4 when leading coefficients are zero. Pad
        # with zeros instead of discarding the roots that were found. t = 0
        # maps to the (+-1, 0) points that the solvers already consider.
        ts[i, :roots.size] = roots

    x = getreal(torch.tensor(ts))
    return x.to(device)

def find_roots(coeffs):
    """Return every (possibly complex) root of the polynomial given by ``coeffs``.

    ``coeffs`` is ordered from the highest-degree term down, as ``numpy.roots``
    expects.
    """
    polynomial_roots = np.roots(coeffs)
    return polynomial_roots

