# PCGrad ("Gradient Surgery for Multi-Task Learning", Yu et al., NeurIPS 2020):
# https://proceedings.neurips.cc/paper_files/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html

import torch
import numpy as np
from utils import projection


def pcgrad(losses, grads, input):
    """Project away conflicting components of per-task gradients (PCGrad).

    For each task gradient ``g_i`` a modified copy ``g_i^PC`` is built.
    Starting from ``g_i``, the other tasks' gradients ``g_j`` are visited in a
    random order. Whenever the *current* ``g_i^PC`` conflicts with ``g_j``
    (``g_i^PC . g_j < 0``), its component along ``g_j`` is removed::

        g_i^PC <- g_i^PC - (g_i^PC . g_j) / ||g_j||^2 * g_j

    Args:
        losses: Unused here; kept so the signature matches other
            gradient-combination methods.
        grads: Sequence of per-task gradient tensors. Each is flattened before
            the dot products, so all must contain the same number of elements
            (and share a dtype, as required by ``torch.dot``).
        input: Unused here; kept for signature compatibility.

    Returns:
        A list of projected gradients in task order: ``result[k]`` corresponds
        to ``grads[k]`` and has ``grads[k]``'s shape. An empty ``grads`` gives
        ``[]``. If a gradient needed no projection, the returned tensor may be
        a view sharing storage with the input.
    """
    flat = [g.flatten() for g in grads]
    num_tasks = len(flat)

    new_grads = []
    for i, g_i in enumerate(flat):
        g_pc = g_i
        # Only the order in which the *other* tasks are visited affects the
        # result, so only that order is randomized; the output keeps task order.
        for j in np.random.permutation(num_tasks):
            if j == i:
                continue
            g_j = flat[j]
            dot = torch.dot(g_pc, g_j)
            # Test the already-projected gradient, not the raw one: an earlier
            # projection may have removed a conflict that existed originally.
            if dot < 0:
                # dot < 0 implies g_j is non-zero, so this division is safe.
                # Out-of-place update: the caller's tensors are never modified.
                g_pc = g_pc - (dot / torch.dot(g_j, g_j)) * g_j
        new_grads.append(g_pc.reshape(grads[i].shape))

    return new_grads
