from qpsolvers import solve_qp
import numpy as np
import torch

from utils import cprint

# Small positive jitter added to the diagonal of the gradient Gram matrix in
# `aggregation` so that the QP's quadratic term is strictly positive definite.
EPS = 1.0e-8

def flatGrad(y, x, retain_graph=False, create_graph=False):
    """Return the gradient of ``y`` w.r.t. ``x`` flattened into one 1-D tensor.

    Args:
        y: tensor to differentiate, passed as ``outputs`` to
            ``torch.autograd.grad`` (no ``grad_outputs`` is supplied, so it is
            presumably a scalar).
        x: tensor or iterable of tensors to differentiate against.
        retain_graph: keep the autograd graph alive after this call.
        create_graph: build a graph of the gradient itself (for higher-order
            derivatives); forces ``retain_graph=True``.

    Returns:
        The gradients of every element of ``x``, each flattened in row-major
        order, concatenated in the order of ``x``.
    """
    if create_graph:
        # A differentiable gradient is of no use if its graph is freed.
        retain_graph = True
    grads = torch.autograd.grad(
        y, x, retain_graph=retain_graph, create_graph=create_graph)
    # reshape rather than view: autograd may return non-contiguous gradients
    # (e.g. when a parameter is used through a transpose), and view(-1) raises
    # on those.
    return torch.cat([g.reshape(-1) for g in grads])

def aggregation(objectives_tensor, preferences_tensor, actor, device, c=1.0, solver=None):
    """Scalarise several objectives using QP-corrected preference weights.

    For every objective, the flattened gradient w.r.t. the trainable
    parameters of ``actor`` is computed. Stacking them gives ``B``
    (n_objs x n_params). With current weights ``w`` and
    ``phi = c * ||sum_i w_i g_i||``, the QP

        min_lam  phi * lam^T S lam + (S w)^T lam
        s.t.     sum(lam) = 1,  lam >= 0

    is solved, where ``S = B B^T`` (symmetrised and shifted to be positive
    definite). The corrected weights are ``w + phi * lam / ||lam^T B||``,
    renormalised to sum to one. If the QP fails, returns no solution, or the
    update is non-finite, the original ``preferences_tensor`` is used instead
    (a message is printed via ``cprint``).

    Args:
        objectives_tensor: indexable sequence of objectives (supports ``len``
            and ``[i]``), each differentiable w.r.t. ``actor``'s parameters.
        preferences_tensor: one weight per objective (presumably 1-D).
        actor: object whose ``parameters()`` the gradients are taken against.
        device: device on which the QP solution tensor is created.
        c: scale factor applied to ``phi``.
        solver: optional qpsolvers backend name forwarded to ``solve_qp``.
            When ``None``, no ``solver`` argument is passed (original call).

    Returns:
        ``new_preferences @ objectives_tensor``. On the success path the new
        weights are computed under ``torch.no_grad`` and so act as constants.
        On the fallback path ``preferences_tensor`` itself is used.
    """
    # torch.autograd.grad raises for tensors that do not require grad, so
    # frozen parameters are excluded from the gradient computation.
    params = [p for p in actor.parameters() if p.requires_grad]

    grad_list = []
    phi_grad_list = []
    for obj_idx in range(len(objectives_tensor)):
        # retain_graph=True: the graph is needed again for the next objective
        # and for the caller's backward pass through the returned scalar.
        g_tensor = flatGrad(objectives_tensor[obj_idx], params, retain_graph=True)
        grad_list.append(g_tensor)
        phi_grad_list.append(g_tensor * preferences_tensor[obj_idx])

    with torch.no_grad():
        preferences = preferences_tensor.detach().cpu().numpy().astype(np.float64)

        # phi = c * || sum_i w_i * g_i ||
        phi_grad = torch.stack(phi_grad_list).sum(dim=0)
        phi_grad = phi_grad.detach().cpu().numpy().astype(np.float64)
        phi = c * np.linalg.norm(phi_grad)

        # S = B B^T, the Gram matrix of the objective gradients.
        B_mat = torch.stack(grad_list)  # n_objs x n_params
        S_mat = (B_mat @ B_mat.T).detach().cpu().numpy().astype(np.float64)
        S_mat = 0.5 * (S_mat + S_mat.T)
        # Shift the spectrum so S_mat is positive definite (invertible).
        # eigvalsh is the symmetric eigen-solver and always returns real
        # values. eigvals may return a complex array, which min(0.0, ...)
        # cannot compare.
        min_eig_val = min(0.0, float(np.min(np.linalg.eigvalsh(S_mat))))
        S_mat += np.eye(S_mat.shape[0]) * (-min_eig_val + EPS)

        # QP in qpsolvers form: min 0.5 lam^T P lam + q^T lam,
        # A lam = b (weights sum to 1), lam >= lb (non-negative).
        P_mat = 2 * phi * S_mat
        q_vec = S_mat @ preferences
        A_mat = np.ones((1, len(preferences)))
        b_vec = np.array([1.0])
        lb = np.zeros(len(preferences))
        # Forward `solver` only when given, so the default call is unchanged.
        qp_kwargs = {} if solver is None else {'solver': solver}
        try:
            lam_vector = solve_qp(
                P=P_mat, q=q_vec, A=A_mat, b=b_vec, lb=lb, **qp_kwargs)
            if lam_vector is None:
                # Explicit check rather than `assert`, which is stripped under -O.
                raise ValueError('no solution returned')
            lam_tensor = torch.tensor(lam_vector, dtype=torch.float32, device=device)
            new_grad_tensor = (lam_tensor.view((-1, 1)) * B_mat).sum(dim=0)
            new_grad_norm = torch.norm(new_grad_tensor)
            new_preference_tensor = preferences_tensor + phi * lam_tensor / new_grad_norm
            new_preference_tensor /= torch.sum(new_preference_tensor)
            # A zero-norm combined gradient (x/0 or 0/0) produces inf/NaN
            # weights without raising. Treat that like a solver failure.
            if not torch.isfinite(new_preference_tensor).all():
                raise ValueError('non-finite preference update')
        except Exception as e:
            cprint('solve_qp failed: {}'.format(e), 'red')
            new_preference_tensor = preferences_tensor

    return new_preference_tensor @ objectives_tensor
