import os
from algorithm.JOBCD.utils_JOBCD import *
import numpy as np

def nonconvex_orth2d_quad_same_P(H, B):
    """Pick, per batch element, the better of the rotation and reflection solutions.

    Both 2x2 subproblems are solved:

    * ``nonconvex_orth2d_quad_L0_R`` gives V = [[c, s], [-s, c]].
    * ``nonconvex_orth2d_quad_L0_F`` gives V = [[-c, s], [s, c]].

    For each problem the candidate with the smaller returned objective is kept.
    ``torch.argmin`` returns the first minimal index, so ties resolve to the
    rotation candidate.

    Args:
        H: batched quadratic-term matrix; only ``H.shape[0]`` and ``H.device``
            are read here. It is forwarded unchanged to both solvers.
        B: batched linear-term matrix, forwarded unchanged to both solvers.

    Returns:
        Tensor of shape (H.shape[0], 2, 2), allocated with torch's default
        dtype on ``H.device``.
    """
    V_rot, f_rot = nonconvex_orth2d_quad_L0_R(H, B)
    V_ref, f_ref = nonconvex_orth2d_quad_L0_F(H, B)

    # Row 0 of the stack holds the rotation values, row 1 the reflection values.
    winner = torch.argmin(torch.stack([f_rot, f_ref]), dim=0)
    take_rot = winner == 0
    take_ref = winner == 1

    V = torch.zeros(H.shape[0], 2, 2, device=H.device)
    V[take_rot] = V_rot[take_rot]
    V[take_ref] = V_ref[take_ref]
    return V

def nonconvex_orth2d_quad_L0_R(Q, B):
    """Solve the rotation-type 2x2 subproblem, V = [[c, s], [-s, c]].

    The trigonometric coefficients passed to
    ``nonconvex_quadratic_trigonometry_L0`` are those of

        f(V) = <B, V> + 0.5 * vec(V)^T Q vec(V),

    with column-major vec(V) = (c, -s, s, c). This gives
    f = a*c + b*s + C*c**2 + D*s*c + E*s**2.

    Each off-diagonal entry of Q is paired with its transpose, so only the
    symmetric part (Q + Q^T) / 2 is used. This matches
    ``nonconvex_orth2d_quad_L0_F``, so the two branches minimise the same
    objective and their f-values can be compared. For a symmetric Q the
    coefficients are bit-identical to using the upper triangle alone.

    Args:
        Q: indexed as Q[:, i, j] with i, j in 0..3; presumably (N, 4, 4).
        B: indexed as B[:, i, j] with i, j in 0..1; presumably (N, 2, 2).

    Returns:
        (V, fmin): V of shape (N, 2, 2) built from the chosen (c, s), and the
        per-problem objective value from ``nonconvex_quadratic_trigonometry_L0``.
        When c**2 + s**2 == 1, det(V) = 1, so V is a rotation.
    """
    # Linear part <B, V> = cof_a * c + cof_b * s.
    cof_a = B[:, 0, 0] + B[:, 1, 1]
    cof_b = B[:, 0, 1] - B[:, 1, 0]

    # Quadratic part, from 0.5 * vec(V)^T Q vec(V) with vec(V) = (c, -s, s, c).
    # Symmetrised cross terms: 0.5 * (Q[i, j] + Q[j, i]).
    cof_c = 0.5 * Q[:, 0, 0] + 0.5 * (Q[:, 0, 3] + Q[:, 3, 0]) + 0.5 * Q[:, 3, 3]
    cof_d = (- 0.5 * (Q[:, 0, 1] + Q[:, 1, 0]) + 0.5 * (Q[:, 0, 2] + Q[:, 2, 0])
             - 0.5 * (Q[:, 1, 3] + Q[:, 3, 1]) + 0.5 * (Q[:, 2, 3] + Q[:, 3, 2]))
    cof_e = 0.5 * Q[:, 1, 1] - 0.5 * (Q[:, 1, 2] + Q[:, 2, 1]) + 0.5 * Q[:, 2, 2]

    [c, s, fmin] = nonconvex_quadratic_trigonometry_L0(cof_a, cof_b, cof_c, cof_d, cof_e)
    V = torch.stack((torch.stack((c, s), dim=-1),
                     torch.stack((-s, c), dim=-1)), dim=-2)

    return V, fmin

def nonconvex_orth2d_quad_L0_F(H, B):
    """Solve the reflection-type 2x2 subproblem, V = [[-c, s], [s, c]].

    The coefficients correspond to

        f(V) = <B, V> + 0.5 * vec(V)^T H vec(V),

    with column-major vec(V) = (-c, s, s, c). Every off-diagonal entry of H
    enters together with its transpose, so only the symmetric part of H is
    used. When c**2 + s**2 == 1, det(V) = -1, so V is a reflection.

    Args:
        H: indexed as H[:, i, j] with i, j in 0..3; presumably (N, 4, 4).
        B: indexed as B[:, i, j] with i, j in 0..1; presumably (N, 2, 2).

    Returns:
        (V, fmin): V of shape (N, 2, 2) built from the chosen (c, s), and the
        objective value returned by ``nonconvex_quadratic_trigonometry_L0``.
    """
    def h(i, j):
        return H[:, i, j]

    # Linear part <B, V> = cof_a * c + cof_b * s.
    cof_a = -B[:, 0, 0] + B[:, 1, 1]
    cof_b = B[:, 1, 0] + B[:, 0, 1]

    # Quadratic part: coefficients of c**2, s*c and s**2.
    cof_c = 0.5 * h(0, 0) - 0.5 * h(0, 3) - 0.5 * h(3, 0) + 0.5 * h(3, 3)
    cof_d = (- 0.5 * h(0, 1) - 0.5 * h(0, 2) - 0.5 * h(1, 0) + 0.5 * h(1, 3)
             - 0.5 * h(2, 0) + 0.5 * h(2, 3) + 0.5 * h(3, 1) + 0.5 * h(3, 2))
    cof_e = 0.5 * h(1, 1) + 0.5 * h(1, 2) + 0.5 * h(2, 1) + 0.5 * h(2, 2)

    c, s, fmin = nonconvex_quadratic_trigonometry_L0(cof_a, cof_b, cof_c, cof_d, cof_e)

    # Row-major fill of [[-c, s], [s, c]] for every batch element.
    V = torch.stack((-c, s, s, c), dim=-1).reshape(*c.shape, 2, 2)
    return V, fmin

def nonconvex_quadratic_trigonometry_L0(A, B, C, D, E):
    """Minimise f(c, s) = A*c + B*s + C*c**2 + D*s*c + E*s**2 on a candidate set.

    For each of the N problems (one entry per coefficient tensor), f is
    evaluated at these candidate (c, s) pairs, and the best one is kept:

    * the four axis points (0, 1), (0, -1), (1, 0) and (-1, 0);
    * for every value t returned by ``get_all_critical_points``, the two
      antipodal unit vectors +/- (1, t) / sqrt(1 + t**2).

    NOTE(review): t is presumably tan(theta) at stationary points of f; the
    axis points cover theta = 0, pi and +/- pi/2 explicitly.

    Args:
        A, B, C, D, E: 1-D coefficient tensors of equal length N.

    Returns:
        (c, s, fmin): the chosen cosine and sine as float32 tensors of shape
        (N,), and the minimal objective value per problem.
    """
    nequ = A.shape[0]
    device = A.device

    # Axis candidates, one copy per problem.
    axis_cos = torch.tensor([0, 0, 1, -1]).repeat(nequ, 1).to(device)
    axis_sin = torch.tensor([1, -1, 0, 0]).repeat(nequ, 1).to(device)

    t = get_all_critical_points(A, B, C, D, E).clone().detach()
    norm = torch.sqrt(1 + t * t)

    # Column layout: 4 axis points, then +(1, t)/norm, then -(1, t)/norm.
    cos_all = torch.cat([axis_cos, 1 / norm, -1 / norm], dim=1)
    sin_all = torch.cat([axis_sin, t / norm, -t / norm], dim=1)

    A_, B_, C_, D_, E_ = (coef.unsqueeze(1) for coef in (A, B, C, D, E))
    values = (A_ * cos_all + B_ * sin_all + C_ * cos_all ** 2
              + D_ * sin_all * cos_all + E_ * sin_all * sin_all)

    fmin, best = torch.min(values, dim=1)
    pick = best.unsqueeze(1)
    c = cos_all.gather(1, pick).squeeze(1).to(torch.float32)
    s = sin_all.gather(1, pick).squeeze(1).to(torch.float32)
    return c, s, fmin

def get_all_critical_points(a, b, c, d, e):
    """Return candidate tangents for the trigonometric objective's stationary points.

    The objective is ``a*cos + b*sin + c*cos**2 + d*sin*cos + e*sin**2``.
    For each problem i, the roots of the quartic

        c4 t^4 + c3 t^3 + c2 t^2 + c1 t + c0 = 0,   w = c - e,

    are computed with ``np.roots``. Setting the derivative to zero and squaring
    away the linear terms yields this quartic in t = tan(theta). Squaring can
    introduce spurious roots, so the result is a superset of candidates and
    the caller must evaluate the objective at each.

    Fewer than four roots are returned by ``np.roots`` when leading
    coefficients vanish (e.g. d**2 == a**2 makes c4 == 0). In that case the
    roots that were found are kept and the row is padded with 0 instead of
    being discarded. Padding is safe because:

    * t = 0 corresponds to theta = 0 / pi.
    * A root lost "at infinity" corresponds to theta = +/- pi/2.

    ``nonconvex_quadratic_trigonometry_L0`` already tests both of these as
    its axis candidates.

    Args:
        a, b, c, d, e: 1-D tensors of per-problem coefficients (same length N).

    Returns:
        ``getreal`` applied to the (N, 4) complex root tensor, moved to
        ``a.device``. NOTE(review): ``getreal`` is defined elsewhere; it
        presumably maps complex roots to real candidates. Confirm there.
    """
    device = a.device
    a, b, c, d, e = (x.cpu().detach().numpy() for x in (a, b, c, d, e))

    w = c - e
    # One row of quartic coefficients per problem, highest degree first.
    coeffs = np.stack([
        d * d - a * a,                              # c4
        4 * w * d + 2 * a * b,                      # c3
        4 * w * w - 2 * d * d - a * a - b * b,      # c2
        2 * b * a - 4 * d * w,                      # c1
        d * d - b * b,                              # c0
    ], axis=1)

    ts = np.zeros((len(a), 4), dtype=complex)
    for i, row in enumerate(coeffs):
        try:
            roots = np.roots(row)
        except (np.linalg.LinAlgError, ValueError):
            # Non-finite coefficients make the eigenvalue solve fail. Leave
            # this row as zeros (t = 0), which only duplicates axis candidates.
            continue
        # Degree may have dropped below 4; the remaining slots stay 0.
        ts[i, :len(roots)] = roots

    x = getreal(torch.tensor(ts))
    return x.to(device)

def find_roots(coeffs):
    """Return the roots of the polynomial whose coefficients are ``coeffs``.

    This is a thin wrapper around ``np.roots``. Coefficients are ordered from
    the highest degree down. Leading zeros are trimmed, so the number of roots
    can be smaller than ``len(coeffs) - 1``.
    """
    roots = np.roots(coeffs)
    return roots
