import torch
from torch import nn


class RotoGradMagnitude(nn.Module):
    """Rescale per-task gradients to a common, convergence-weighted magnitude.

    Presumably the gradient-magnitude step of RotoGrad (Javaloy & Valera) --
    TODO confirm against the paper / upstream implementation.

    Every task gradient keeps its direction but is rescaled to the same norm:
    a weighted sum of the current per-task gradient norms.  Each task's weight
    is proportional to its "convergence ratio", i.e. its current gradient norm
    divided by a reference norm, normalized by the sum of all ratios.  The
    reference norms are recorded on the first call and recorded again exactly
    once, on the call where ``counter == update_at`` (the counter is never
    reset, so every later call keeps that second reference).

    Args:
        num_tasks: Number of tasks. Stored, but not used by ``forward``.
        update_at: 0-based call index at which the reference norms are
            re-recorded.
    """

    def __init__(self, num_tasks, update_at=20):
        super().__init__()

        self.num_tasks = num_tasks
        self.update_at = update_at
        # Reference per-task gradient norms (list of tensors); None until the
        # first forward call.
        self.initial_grads = None
        # Number of forward calls performed so far.
        self.counter = 0

    def forward(self, list_losses, grads, rep):
        """Return ``grads`` rescaled to a shared, convergence-weighted norm.

        Args:
            list_losses: Not used by this implementation; kept for interface
                compatibility.
            grads: Sequence of per-task gradient tensors.
            rep: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            A list with one tensor per entry of ``grads``.  Each keeps the
            direction of its input gradient, and its norm is set to the
            weighted sum of the current gradient norms.  A task whose
            gradient is all zeros yields an all-zero tensor rather than NaN.
        """
        eps = 1e-15
        # keepdim=True keeps each norm broadcastable against its gradient.
        grad_norms = [torch.norm(g, keepdim=True) for g in grads]

        # Record reference norms on the first call and once more on the call
        # with 0-based index ``update_at``.
        if self.initial_grads is None or self.counter == self.update_at:
            self.initial_grads = grad_norms
        self.counter += 1

        # Current norm relative to the reference norm, per task.
        conv_ratios = [
            x / torch.clamp(y, eps) for x, y in zip(grad_norms, self.initial_grads)
        ]
        # Normalize the ratios by their (clamped) sum to obtain the weights.
        total_ratio = torch.clamp(sum(conv_ratios), eps)
        alphas = [x / total_ratio for x in conv_ratios]

        # Common target norm shared by every task gradient.
        weighted_sum_norms = sum(a * n for a, n in zip(alphas, grad_norms))
        # Clamp the divisor: an all-zero gradient would otherwise give
        # 0/0 = NaN and poison the combined update.
        return [
            g / torch.clamp(n, eps) * weighted_sum_norms
            for g, n in zip(grads, grad_norms)
        ]
