import torch
import torch.nn as nn


class GradNorm(nn.Module):
    """GradNorm-style dynamic weighting of ``num_tasks`` task gradients.

    Each task ``i`` owns one learnable scalar ``weight_[i]``. The effective task
    weights (see ``weight``) are ``exp(weight_[i])`` rescaled so they sum to
    ``num_tasks``. ``callback`` accumulates the gradient of the GradNorm
    objective into ``weight_[i].grad``. It then returns the task gradients
    scaled by the current weights. Stepping ``weight_`` is left to the caller
    (presumably an optimizer that owns these parameters — confirm).

    Args:
        num_tasks: number of tasks (losses / gradients) being balanced.
        alpha: exponent applied to each task's relative loss ratio
            ``loss_i / mean(loss)`` when building the per-task target norm.
    """

    def __init__(self, num_tasks, alpha):
        super(GradNorm, self).__init__()

        # One raw (log-space) weight per task; initialised to 1 so all tasks start equal.
        self.weight_ = nn.ParameterList([nn.Parameter(torch.ones([]), requires_grad=True) for _ in range(num_tasks)])

        self.num_tasks = num_tasks
        self.alpha = alpha

    @property
    def weight(self):
        """Return the normalised task weights as a list of scalar tensors.

        Each raw parameter is exponentiated. 1e-15 is added so a very negative
        parameter cannot underflow to an exact zero weight. The list is rescaled
        so its sum equals ``num_tasks``. The returned tensors stay attached to
        the autograd graph of ``weight_``.
        """
        ws = [w.exp() + 1e-15 for w in self.weight_]
        norm_coef = self.num_tasks / sum(ws)
        return [w * norm_coef for w in ws]

    def callback(self, losses, grads, input):
        """Accumulate GradNorm gradients into ``weight_`` and reweight ``grads``.

        Args:
            losses: one loss per task (Python scalars or 0-d tensors). They are
                assumed to be already normalised by their initial values, i.e.
                ``L_i(t) / L_i(0)``.
            grads: one gradient tensor per task, all of the same shape. They are
                assumed to be detached from any autograd graph and, per the
                original note, already rotated by the caller (TODO confirm).
            input: unused; kept for interface compatibility with callers.

        Returns:
            A list with ``grads[i] * weight[i]`` for every task. It is computed
            under ``torch.no_grad()``, so no autograd history is attached.
        """
        # Evaluate the (graph-building) property once and share it below.
        weights = self.weight

        grad_norms = [g.norm(p=2) for g in grads]

        # Norm of the mean weighted gradient. It is detached so the target acts
        # as a constant and only the ``grad_norm * w`` terms push on ``weight_``.
        mean_grad = sum(g * w for g, w in zip(grads, weights)).detach() / len(grads)
        mean_grad_norm = mean_grad.norm(p=2)
        mean_loss = sum(losses) / len(losses)

        # L1 GradNorm objective summed over tasks. A single backward gives the
        # same accumulated .grad as one backward per task, at a fraction of the cost.
        gradnorm_loss = 0.
        for loss, grad_norm, w in zip(losses, grad_norms, weights):
            inverse_ratio = (loss / mean_loss) ** self.alpha
            # float() makes the per-task target a constant w.r.t. autograd.
            target = mean_grad_norm * float(inverse_ratio)
            gradnorm_loss = gradnorm_loss + torch.abs(grad_norm * w - target)
        gradnorm_loss.backward()

        # backward() only fills .grad; the parameter values, and hence the
        # weights, are unchanged, so the weights computed above can be reused.
        with torch.no_grad():
            new_grads = [g * w for g, w in zip(grads, weights)]

        return new_grads
