from utils import *
from scipy.optimize import minimize_scalar

def pcgrad(grads):
    """Combine two gradients with PCGrad-style conflict projection.

    Only columns 0 and 1 of ``grads`` are read, as ``g1`` and ``g2``.
    When they conflict (negative inner product), each one is projected onto
    the normal plane of the other -- ``g1 - (g1.g2 / |g2|^2) g2`` and
    ``g2 - (g1.g2 / |g1|^2) g1`` -- and the two projections are averaged.
    Otherwise the plain average ``(g1 + g2) / 2`` is returned.

    Args:
        grads: 2-D array/tensor whose columns 0 and 1 are the two gradients.

    Returns:
        The combined 1-D update direction.
    """
    first, second = grads[:, 0], grads[:, 1]
    sq_first = first.dot(first).item()
    cross = first.dot(second).item()
    sq_second = second.dot(second).item()

    if not cross < 0:
        # No conflict: simple average of the two gradients.
        return (first + second) / 2

    # cross < 0 can only happen when both gradients are non-zero, so both
    # squared norms are strictly positive and the divisions are safe.
    keep_first = 1 - cross / sq_first
    keep_second = 1 - cross / sq_second
    return (keep_first * first + keep_second * second) / 2


def mgd(grads):
    """Return the minimum-norm convex combination of two gradients.

    With ``g1 = grads[:, 0]`` and ``g2 = grads[:, 1]``, the weight ``x`` in
    [0, 1] minimising ``||x * g1 + (1 - x) * g2||^2`` is computed in closed
    form, and that combination is returned.

    Args:
        grads: 2-D array/tensor whose columns 0 and 1 are the two gradients.

    Returns:
        The 1-D direction ``x * g1 + (1 - x) * g2``.
    """
    first, second = grads[:, 0], grads[:, 1]
    a = first.dot(first).item()
    b = first.dot(second).item()
    d = second.dot(second).item()

    if b < min(a, d):
        # The stationary point of the quadratic lies inside (0, 1); the
        # denominator equals (a - b) + (d - b) > 0, and 1e-8 guards it further.
        weight = (d - b) / (a + d - 2 * b + 1e-8)
    else:
        # The minimum sits at a vertex: keep only the shorter gradient.
        weight = 1 if a < d else 0

    return weight * first + (1 - weight) * second


def moco(grads, y, lambd, beta, gamma, rho):
    """Run one MoCo-style step: track gradients, update weights, combine.

    Args:
        grads: current gradient matrix; presumably (dim, 2) with one column
            per objective -- TODO confirm against callers.
        y: running gradient estimate; must broadcast against ``grads``
            (presumably the same shape).
        lambd: current objective weights; reshaped to (2, 1) after the update.
        beta: step size pulling ``y`` toward ``grads``.
        gamma: step size for the weight update.
        rho: coefficient of the ``rho * lambd`` regularisation term.

    Returns:
        Tuple ``(direction, y, lambd)``: ``direction`` is ``y @ lambd``
        flattened to 1-D; ``y`` and ``lambd`` are the updated state.
    """
    # Move the tracking variable a fraction beta of the way toward grads.
    y = y - beta * (y - grads)

    # Gradient of 0.5 * ||y @ lambd||^2 + 0.5 * rho * ||lambd||^2 w.r.t. lambd.
    # projection2simplex comes from utils; it is assumed, from its name, to
    # project onto the probability simplex.
    y_t = torch.transpose(y, 0, 1)
    lambd_grad = y_t @ (y @ lambd) + rho * lambd
    lambd = projection2simplex(lambd - gamma * lambd_grad).view([2, 1])

    direction = y @ lambd
    return direction.view(-1), y, lambd


def modo(grads1, grads2, lambd, gamma, rho):
    """Run one MoDo-style step using two gradient matrices.

    Args:
        grads1: first gradient matrix; presumably (dim, 2) with one column
            per objective -- TODO confirm against callers.
        grads2: second gradient matrix, same layout as ``grads1``
            (presumably an independent sample).
        lambd: current objective weights; reshaped to (2, 1) after the update.
        gamma: step size for the weight update.
        rho: coefficient of the ``rho * lambd`` regularisation term.

    Returns:
        Tuple ``(direction, lambd)``: ``direction`` is the average of the two
        gradient matrices times the updated ``lambd``, flattened to 1-D.
    """
    # Weight-gradient estimate built by crossing the two gradient matrices.
    # projection2simplex comes from utils; it is assumed, from its name, to
    # project onto the probability simplex.
    cross_term = torch.transpose(grads1, 0, 1) @ (grads2 @ lambd)
    lambd = projection2simplex(lambd - gamma * (cross_term + rho * lambd)).view([2, 1])

    averaged = 0.5 * (grads1 + grads2)
    direction = averaged @ lambd
    return direction.view(-1), lambd

# -----------------------------------------------------------------------------------------------------

def cagrad(grads, c=0.5):
    """Combine two gradients with CAGrad (conflict-averse gradient).

    Let ``g1 = grads[:, 0]``, ``g2 = grads[:, 1]`` and ``g0 = (g1 + g2) / 2``.
    The weight ``x`` of ``g_w = x * g1 + (1 - x) * g2`` is found numerically
    on [0, 1] by minimising ``g_w^T g0 + c * ||g0|| * ||g_w||``. The result is
    ``(g0 + c * ||g0|| / ||g_w|| * g_w) / (1 + c)``. A small 1e-4 is added
    inside each square root and to the final denominator for stability.

    Args:
        grads: 2-D array/tensor whose columns 0 and 1 are the two gradients.
        c: conflict-aversion strength; with ``c = 0`` the result is ``g0``.

    Returns:
        The combined 1-D update direction.
    """
    first, second = grads[:, 0], grads[:, 1]
    mean_grad = (first + second) / 2

    sq_first = first.dot(first).item()
    cross = first.dot(second).item()
    sq_second = second.dot(second).item()

    # ||g0|| = 0.5 * ||g1 + g2||, regularised inside the square root.
    mean_norm = 0.5 * np.sqrt(sq_first + sq_second + 2 * cross + 1e-4)
    radius = c * mean_norm

    # Expansions in the mixing weight x:
    #   ||g_w||^2 = quad * x^2 + 2 * lin * x + sq_second
    #   g_w^T g0  = 0.5 * quad * x + (0.5 + x) * lin + sq_second
    quad = sq_first + sq_second - 2 * cross
    lin = cross - sq_second

    def objective(x):
        return radius * np.sqrt(x ** 2 * quad + 2 * x * lin + sq_second + 1e-4) + \
            0.5 * x * quad + (0.5 + x) * lin + sq_second

    weight = minimize_scalar(objective, bounds=(0, 1), method='bounded').x

    mixed = weight * first + (1 - weight) * second
    mixed_norm = np.sqrt(weight ** 2 * sq_first + (1 - weight) ** 2 * sq_second
                         + 2 * weight * (1 - weight) * cross + 1e-4)

    # Scale g_w so its length is about c * ||g0||, then renormalise by 1 + c.
    scale = radius / (mixed_norm + 1e-4)
    return (mean_grad + scale * mixed) / (1 + c)

