import torch
from utils import unitary


def imtl_G(losses, grads, input):
    """Rescale per-task gradients with the IMTL-G closed-form weights.

    Let ``g_i`` be the flattened gradient of task ``i`` and ``u_i =
    unitary(g_i)``. Let ``D`` have rows ``g_1 - g_i`` and ``U`` have rows
    ``u_1 - u_i`` for ``i = 2..T``. The weights are then::

        alpha[1:] = g_1 @ U.T @ inv(D @ U.T)
        alpha[0]  = 1 - sum(alpha[1:])

    These weights make ``sum_i alpha_i * g_i`` have the same inner product
    with every ``u_i``. That is an equal projection onto each task direction,
    assuming ``unitary`` L2-normalises its argument (TODO confirm).

    The weights are finally multiplied by the number of tasks before being
    applied to the original (unflattened) gradients.

    Args:
        losses: Not used by this function. Presumably kept so all weighting
            functions share one signature; TODO confirm with callers.
        grads: Sequence of per-task gradient tensors. They are flattened for
            the weight computation, so all must have the same number of
            elements.
        input: Not used by this function (see ``losses``).

    Returns:
        A list ``[alpha_i * grads[i] for each i]`` with the same shapes as
        ``grads``. With fewer than two tasks there is nothing to balance, and
        ``grads`` itself is returned unchanged.

    Raises:
        RuntimeError: (``torch.linalg.LinAlgError`` on recent PyTorch) if
            ``D @ U.T`` is singular, e.g. when two task gradients are
            identical.
    """
    num_tasks = len(grads)
    # Zero or one task: no weighting is possible or needed.
    if num_tasks <= 1:
        return grads

    flat_grads = [g.flatten() for g in grads]
    # Normalise each gradient exactly once (the reference task is reused T-1 times).
    unit_grads = [unitary(g) for g in flat_grads]
    ref_grad, ref_unit = flat_grads[0], unit_grads[0]

    # Rows i = 2..T: D[i] = g_1 - g_i and U[i] = u_1 - u_i, each (T-1, K).
    grad_diffs = torch.stack([ref_grad - g for g in flat_grads[1:]], dim=0)
    unit_diffs = torch.stack([ref_unit - u for u in unit_grads[1:]], dim=0)

    # (T-1, T-1) system matrix D @ U^T.
    du_t = grad_diffs @ unit_diffs.T

    # alpha[1:] = g_1 U^T (D U^T)^{-1}. Transposing gives the linear system
    # (D U^T)^T alpha[1:] = U g_1. Solving it directly avoids forming the
    # explicit inverse, which is slower and less numerically stable.
    rest = torch.linalg.solve(du_t.T, unit_diffs @ ref_grad)
    first = (1 - rest.sum()).unsqueeze(0)
    alphas = torch.cat((first, rest), dim=0) * num_tasks

    return [a * g for a, g in zip(alphas, grads)]
