

class DummyBalancer(object):
    """Baseline multi-task "balancer": a fixed-weight sum of per-task losses.

    Every task loss contributes with weight 1, except the task whose id equals
    ``weighted_task`` (``'right'`` by default, the orientation task), whose
    loss is multiplied by ``alpha``.
    """

    def __init__(self, alpha=5, weighted_task='right'):
        # Multiplier applied to the loss of ``weighted_task``.
        self.alpha = alpha
        # Task id whose loss is scaled by ``alpha``; 'right' (the orientation
        # task) keeps the previous hard-coded behavior as the default.
        self.weighted_task = weighted_task
        # Python floats of the (weighted) per-task losses from the most recent
        # step(), in ``decoders.items()`` order; None until step() is called.
        self.losses = None

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Zero gradients, compute all task losses and backpropagate their sum.

        Args:
            input: Batch passed to ``encoder``.
            targets: Mapping from task id to that task's target, indexed with
                the keys of ``decoders``.
            encoder: Shared module; called as ``encoder(input)`` and zeroed
                with ``zero_grad()``.
            decoders: Mapping of task id -> head that also supports
                ``zero_grad()`` (presumably an ``nn.ModuleDict`` — confirm
                with callers). Each head is applied to the encoder output.
            criteria: Mapping from task id to ``loss_fn(prediction, target)``
                returning a single-element tensor.
            **kwargs: Accepted for interface compatibility; ignored.

        Side effects:
            Leaves the gradients of the summed weighted loss in the ``.grad``
            of the encoder and decoder parameters (no optimizer step is taken
            here) and stores the per-task loss values in ``self.losses``.
        """
        encoder.zero_grad()
        decoders.zero_grad()

        hrepr = encoder(input)

        losses = []
        for task_id, decoder in decoders.items():
            loss = criteria[task_id](decoder(hrepr), targets[task_id])
            if task_id == self.weighted_task:
                loss = self.alpha * loss
            losses.append(loss)

        # A single backward over the sum lets every task's gradient accumulate
        # in the shared encoder. Zeroing the encoder between per-task backward
        # passes would leave only the last task's gradient there. Each head
        # still receives only its own task's gradient, as long as heads do not
        # share parameters.
        if losses:
            sum(losses).backward()

        self.losses = [loss.item() for loss in losses]
