import torch
import torch.nn as nn


class GradDrop(nn.Module):
    """Gradient Sign Dropout (GradDrop) masking for multi-task learning.

    Given each task's gradient w.r.t. a shared tensor ``input``, keeps at every
    element only the gradients of one sign. The sign is chosen at random: the
    positive sign is kept with probability equal to the positive-sign purity
    of the tasks' sign-corrected gradients at that element, and the negative
    sign is kept otherwise.

    Args:
        num_tasks: Number of tasks whose gradients are masked.
        leak_parameters: Optional per-task leak values, one per task. A task
            with leak ``l`` keeps a fraction ``l`` of its gradient where it is
            dropped; ``l = 1`` disables dropping for that task. Defaults to
            ``0`` for every task.
    """

    def __init__(self, num_tasks, leak_parameters=None):
        super().__init__()

        self.num_tasks = num_tasks
        self.leaks = leak_parameters if leak_parameters is not None else [0] * num_tasks

    def callback(self, losses, grads, input):
        """Return the GradDrop-masked gradient of each task.

        Args:
            losses: Per-task losses. Not read by this method; presumably
                expected to be already normalized (divided by the initial loss).
            grads: Sequence of at least ``num_tasks`` tensors, the gradient of
                each task's loss w.r.t. ``input`` (assumed to share ``input``'s
                shape). If they have more than one dimension, dim 0 is treated
                as the batch dimension.
            input: The shared tensor the gradients were taken against; only
                its sign is used.

        Returns:
            List of ``num_tasks`` tensors ``mask_i * grads[i]``, each with the
            shape of ``grads[i]``.
        """
        batched = grads[0].dim() > 1
        input_sign = input.sign()

        # Sign-corrected gradients: sign(input) * grad. When gradients are
        # batch-separated they are summed over the batch dimension
        # (keepdim=True) so the masks broadcast back over the batch below.
        sign_grads = []
        for i in range(self.num_tasks):
            corrected = input_sign * grads[i]
            if batched:
                corrected = corrected.sum(dim=0, keepdim=True)
            sign_grads.append(corrected)

        # Positive-sign purity in [0, 1]: 1 where every task's sign-corrected
        # gradient is positive, 0 where every one is negative. Elements where
        # all sign-corrected gradients are 0 give 0/0 = NaN; NaN compares False
        # below, so those elements simply fall back to the leak value.
        odds = 0.5 * (1 + sum(sign_grads) / sum(g.abs() for g in sign_grads))

        samples = torch.rand(odds.size(), device=grads[0].device)
        keep_positive = odds > samples
        keep_negative = odds < samples

        new_grads = []
        for i in range(self.num_tasks):
            keep = ((keep_positive & (sign_grads[i] > 0))
                    | (keep_negative & (sign_grads[i] < 0)))
            leak = self.leaks[i]
            mask = leak + (1 - leak) * keep.to(grads[i].dtype)
            new_grads.append(mask * grads[i])

        return new_grads
